Commit 6101a8fb authored by Tim Dettmers's avatar Tim Dettmers
Browse files

Added pre and post device call to transform.

parent 320eacb4
...@@ -1214,6 +1214,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No ...@@ -1214,6 +1214,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
ptrA = get_ptr(A) ptrA = get_ptr(A)
ptrOut = get_ptr(out) ptrOut = get_ptr(out)
is_on_gpu([A, out]) is_on_gpu([A, out])
prev_device = pre_call(A.device)
if to_order == 'col32': if to_order == 'col32':
if transpose: if transpose:
lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2) lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
...@@ -1236,8 +1237,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No ...@@ -1236,8 +1237,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2) lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
else: else:
raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}') raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}')
post_call(prev_device)
return out, new_state return out, new_state
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment