Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
ab72a129
Commit
ab72a129
authored
Aug 04, 2022
by
Tim Dettmers
Browse files
Added pre/post device call for extract outliers.
parent
cc5b3238
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
2 deletions
+4
-2
bitsandbytes/functional.py
bitsandbytes/functional.py
+4
-2
No files found.
bitsandbytes/functional.py
View file @
ab72a129
...
@@ -1198,6 +1198,7 @@ def get_special_format_str():
...
@@ -1198,6 +1198,7 @@ def get_special_format_str():
def
transform
(
A
,
to_order
,
from_order
=
'row'
,
out
=
None
,
transpose
=
False
,
state
=
None
,
ld
=
None
):
def
transform
(
A
,
to_order
,
from_order
=
'row'
,
out
=
None
,
transpose
=
False
,
state
=
None
,
ld
=
None
):
prev_device
=
pre_call
(
A
.
device
)
if
state
is
None
:
state
=
(
A
.
shape
,
from_order
)
if
state
is
None
:
state
=
(
A
.
shape
,
from_order
)
else
:
from_order
=
state
[
1
]
else
:
from_order
=
state
[
1
]
if
out
is
None
:
out
,
new_state
=
get_transform_buffer
(
state
[
0
],
A
.
dtype
,
A
.
device
,
to_order
,
state
[
1
],
transpose
)
if
out
is
None
:
out
,
new_state
=
get_transform_buffer
(
state
[
0
],
A
.
dtype
,
A
.
device
,
to_order
,
state
[
1
],
transpose
)
...
@@ -1214,7 +1215,6 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
...
@@ -1214,7 +1215,6 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
ptrA
=
get_ptr
(
A
)
ptrA
=
get_ptr
(
A
)
ptrOut
=
get_ptr
(
out
)
ptrOut
=
get_ptr
(
out
)
is_on_gpu
([
A
,
out
])
is_on_gpu
([
A
,
out
])
prev_device
=
pre_call
(
A
.
device
)
if
to_order
==
'col32'
:
if
to_order
==
'col32'
:
if
transpose
:
if
transpose
:
lib
.
ctransform_row2col32T
(
get_ptr
(
A
),
get_ptr
(
out
),
dim1
,
dim2
)
lib
.
ctransform_row2col32T
(
get_ptr
(
A
),
get_ptr
(
out
),
dim1
,
dim2
)
...
@@ -1237,8 +1237,8 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
...
@@ -1237,8 +1237,8 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
lib
.
ctransform_ampere2row
(
get_ptr
(
A
),
get_ptr
(
out
),
dim1
,
dim2
)
lib
.
ctransform_ampere2row
(
get_ptr
(
A
),
get_ptr
(
out
),
dim1
,
dim2
)
else
:
else
:
raise
NotImplementedError
(
f
'Transform function not implemented: From
{
from_order
}
to
{
to_order
}
'
)
raise
NotImplementedError
(
f
'Transform function not implemented: From
{
from_order
}
to
{
to_order
}
'
)
post_call
(
prev_device
)
post_call
(
prev_device
)
return
out
,
new_state
return
out
,
new_state
...
@@ -1451,10 +1451,12 @@ def extract_outliers(A, SA, idx):
...
@@ -1451,10 +1451,12 @@ def extract_outliers(A, SA, idx):
ptrIdx
=
get_ptr
(
idx
)
ptrIdx
=
get_ptr
(
idx
)
ptrOut
=
get_ptr
(
out
)
ptrOut
=
get_ptr
(
out
)
prev_device
=
pre_call
(
A
.
device
)
if
formatA
==
'col_turing'
:
if
formatA
==
'col_turing'
:
lib
.
cextractOutliers_turing
(
ptrA
,
ptrIdx
,
ptrOut
,
idx_size
,
rows
,
cols
)
lib
.
cextractOutliers_turing
(
ptrA
,
ptrIdx
,
ptrOut
,
idx_size
,
rows
,
cols
)
elif
formatA
==
'col_ampere'
:
elif
formatA
==
'col_ampere'
:
lib
.
cextractOutliers_ampere
(
ptrA
,
ptrIdx
,
ptrOut
,
idx_size
,
rows
,
cols
)
lib
.
cextractOutliers_ampere
(
ptrA
,
ptrIdx
,
ptrOut
,
idx_size
,
rows
,
cols
)
post_call
(
prev_device
)
return
out
return
out
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment