Commit 2bb5c00b authored by Tim Dettmers

Added pre/post call to all lib calls. Fixes #120

parent 29ab3a6b
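For context, every hunk below applies the same guard: switch to the CUDA device that holds the input tensors before the raw `lib` call and switch back afterwards. A minimal sketch of what `pre_call`/`post_call` amount to, assuming the usual torch.cuda device-switching idiom (the exact bodies in `bitsandbytes/functional.py` may differ):

```python
import torch

def pre_call(device):
    # Remember the CUDA device that is currently active, then make the
    # device holding the input tensors the active one for the kernel launch.
    prev_device = torch.cuda.current_device()
    torch.cuda.set_device(device)
    return prev_device

def post_call(prev_device):
    # Restore whatever device was active before the library call.
    torch.cuda.set_device(prev_device)
```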
@@ -770,6 +770,8 @@ def optimizer_update_32bit(
            f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}'
        )
+   prev_device = pre_call(g.device)
+   is_on_gpu([g, p, state1, state2, unorm_vec])
    if g.dtype == torch.float32 and state1.dtype == torch.float32:
        str2optimizer32bit[optimizer_name][0](
            get_ptr(g),
@@ -812,6 +814,7 @@ def optimizer_update_32bit(
        raise ValueError(
            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
        )
+   post_call(prev_device)
def optimizer_update_8bit(
@@ -890,6 +893,8 @@ def optimizer_update_8bit(
    if max_unorm > 0.0:
        param_norm = torch.norm(p.data.float())
+   prev_device = pre_call(g.device)
+   is_on_gpu([g, p, state1, state2, unorm_vec, qmap1, qmap2, max1, max2, new_max1, new_max2])
    if g.dtype == torch.float32 and state1.dtype == torch.uint8:
        str2optimizer8bit[optimizer_name][0](
            get_ptr(p),
@@ -942,6 +947,7 @@ def optimizer_update_8bit(
        raise ValueError(
            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
        )
+   post_call(prev_device)
def optimizer_update_8bit_blockwise(
@@ -964,6 +970,8 @@ def optimizer_update_8bit_blockwise(
    skip_zeros=False,
) -> None:
+   prev_device = pre_call(g.device)
+   is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2])
    if g.dtype == torch.float32 and state1.dtype == torch.uint8:
        str2optimizer8bit_blockwise[optimizer_name][0](
            get_ptr(p),
@@ -1008,6 +1016,7 @@ def optimizer_update_8bit_blockwise(
        raise ValueError(
            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
        )
+   post_call(prev_device)
def percentile_clipping(
@@ -1023,6 +1032,7 @@ def percentile_clipping(
        The current optimization step (number of past gradient norms).
    """
+   prev_device = pre_call(grad.device)
    is_on_gpu([grad, gnorm_vec])
    if grad.dtype == torch.float32:
        lib.cpercentile_clipping_g32(
@@ -1040,6 +1050,7 @@ def percentile_clipping(
        )
    else:
        raise ValueError(f"Gradient type {grad.dtype} not supported!")
+   post_call(prev_device)
    current_gnorm = torch.sqrt(gnorm_vec[step % 100])
    vals, idx = torch.sort(gnorm_vec)
@@ -1796,6 +1807,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
            (cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype
        )
    nnz = cooA.nnz
+   prev_device = pre_call(B.device)
    assert cooA.rowidx.numel() == nnz
    assert cooA.colidx.numel() == nnz
    assert cooA.values.numel() == nnz
@@ -1872,6 +1884,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
            ccolsB,
        )
    # else: assertion error
+   post_call(prev_device)
    return out
...
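Taken together, each wrapped call site in the diff follows the same shape. A hedged illustration of the pattern (the kernel name `lib.some_kernel` and the surrounding helper are hypothetical stand-ins, not functions added by this commit):

```python
def guarded_kernel_launch(g, p, state1):
    # Make the tensors' device the active CUDA device, as in the hunks above.
    prev_device = pre_call(g.device)
    is_on_gpu([g, p, state1])                                  # all inputs must live on a GPU
    lib.some_kernel(get_ptr(g), get_ptr(p), get_ptr(state1))   # hypothetical kernel name
    post_call(prev_device)                                     # restore the previous device
```

Note that the commit calls `post_call` unconditionally after the dtype dispatch rather than inside a `try/finally`, so when the dispatch raises (unsupported dtype combination) the previous device is not restored; the sketch mirrors the diff rather than a hardened variant.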