Unverified Commit fceff07a authored by Xin Yao, committed by GitHub

[PyTorch] Fix fuse_wgrad_accumulation for GroupedLinear (#1488)



* fix fuse_wgrad_accumulation for GroupedLinear
Signed-off-by: Xin Yao <xiny@nvidia.com>

* fix fuse_wgrad_accumulation for GroupedLinear

Signed-off-by: Xin Yao <xiny@nvidia.com>

* update tests

Signed-off-by: Xin Yao <xiny@nvidia.com>

---------

Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 56c0c070
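For context on the diff below, here is a minimal sketch of the contract the new tests exercise with fuse_wgrad_accumulation=True: each weight is expected to carry a pre-allocated FP32 main_grad buffer, the weight gradient is accumulated into that buffer during backward, and param.grad stays None. The constructor/forward call shape and all sizes here are illustrative assumptions, not taken from the test suite:

# Sketch only: sizes, m_splits, and the exact GroupedLinear call signature are
# illustrative assumptions; the fuse_wgrad_accumulation/main_grad contract is
# the behavior asserted by the tests in this commit.
import torch
from transformer_engine.pytorch import GroupedLinear

num_gemms, in_features, out_features = 2, 128, 128
layer = GroupedLinear(
    num_gemms,
    in_features,
    out_features,
    params_dtype=torch.bfloat16,
    device="cuda",
    fuse_wgrad_accumulation=True,
)

# Each weight must expose a pre-allocated FP32 main_grad buffer.
for i in range(num_gemms):
    w = getattr(layer, f"weight{i}")
    w.main_grad = torch.zeros_like(w, dtype=torch.float32)

m_splits = [32, 32]  # illustrative token counts per GEMM
inp = torch.randn(sum(m_splits), in_features, dtype=torch.bfloat16, device="cuda")
out = layer(inp, m_splits)
out.sum().backward()

for i in range(num_gemms):
    w = getattr(layer, f"weight{i}")
    assert w.grad is None  # wgrad is not routed to .grad ...
    wgrad = getattr(layer, f"weight{i}").main_grad  # ... it is accumulated here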
@@ -1400,7 +1400,9 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization):
     assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])
 
 
-def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False):
+def _test_grouped_linear_accuracy(
+    block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
+):
     reset_rng_states()
     if fp8:
         FP8GlobalStateManager.reset()
@@ -1447,6 +1449,10 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, f
 
     outputs = [out, inp_hidden_states.grad]
     for p in block.parameters():
         if p.requires_grad:
-            outputs.append(p.grad)
+            if getattr(p, "main_grad", None) is not None:
+                outputs.append(p.main_grad)
+                assert p.grad is None  # grad should be None if fuse_wgrad_accumulation is True
+            else:
+                outputs.append(p.grad)
     return outputs
@@ -1458,8 +1464,17 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, f
 @pytest.mark.parametrize("fp8", all_boolean)
 @pytest.mark.parametrize("recipe", fp8_recipes)
 @pytest.mark.parametrize("fp8_model_params", all_boolean)
+@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
 def test_grouped_linear_accuracy(
-    dtype, num_gemms, bs, model, fp8, recipe, fp8_model_params, parallel_mode=None
+    dtype,
+    num_gemms,
+    bs,
+    model,
+    fp8,
+    recipe,
+    fp8_model_params,
+    fuse_wgrad_accumulation,
+    parallel_mode=None,
 ):
     if fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)
@@ -1481,6 +1496,7 @@ def test_grouped_linear_accuracy(
             params_dtype=dtype,
             parallel_mode=parallel_mode,
             device="cuda",
+            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
         ).eval()
         sequential_linear = torch.nn.ModuleList(
             [
@@ -1491,6 +1507,7 @@ def test_grouped_linear_accuracy(
                     params_dtype=dtype,
                     parallel_mode=parallel_mode,
                     device="cuda",
+                    fuse_wgrad_accumulation=fuse_wgrad_accumulation,
                 ).eval()
                 for _ in range(num_gemms)
             ]
@@ -1501,12 +1518,16 @@ def test_grouped_linear_accuracy(
     for i in range(num_gemms):
         sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone())
         sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone())
+        if fuse_wgrad_accumulation:
+            weight_i = getattr(grouped_linear, f"weight{i}")
+            weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
+            sequential_linear[i].weight.main_grad = weight_i.main_grad.clone()
 
     outputs_ref = _test_grouped_linear_accuracy(
-        sequential_linear, num_gemms, bs, dtype, config, recipe, fp8
+        sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
     )
     outputs = _test_grouped_linear_accuracy(
-        grouped_linear, num_gemms, bs, dtype, config, recipe, fp8
+        grouped_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
     )
 
     # Shoule be bit-wise match
@@ -1527,6 +1548,7 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode, recipe):
         recipe=recipe,
         fp8_model_params=True,
         parallel_mode=parallel_mode,
+        fuse_wgrad_accumulation=True,
     )
@@ -1541,6 +1563,7 @@ def test_grouped_linear_accuracy_single_gemm(recipe):
         fp8=True,
         recipe=recipe,
         fp8_model_params=True,
+        fuse_wgrad_accumulation=True,
     )
@@ -178,7 +178,6 @@ class _GroupedLinear(torch.autograd.Function):
 
         if is_grad_enabled:
-            saved_inputs, saved_weights = [], []
             ctx.weights_shape_1 = weights[0].shape[1]
 
             tensors_to_save, tensor_objects = prepare_for_saving(*inputmats, *weights_fp8, *biases)
@@ -186,9 +185,11 @@ class _GroupedLinear(torch.autograd.Function):
             ctx.tensor_objects = tensor_objects
 
             ctx.weights_requires_grad = weights[0].requires_grad
+            if fuse_wgrad_accumulation and ctx.weights_requires_grad:
+                ctx.main_grads = [weights[i].main_grad for i in range(num_gemms)]
+            else:
+                ctx.main_grads = [None] * num_gemms
             ctx.device = device
-            ctx.saved_inputs = saved_inputs
-            ctx.saved_weights = saved_weights
             ctx.grad_output_quantizers = grad_output_quantizers
             ctx.m_splits = m_splits
             ctx.num_gemms = num_gemms
@@ -220,7 +221,7 @@ class _GroupedLinear(torch.autograd.Function):
             inputmats = saved_tensors[:N]
             weights = saved_tensors[N : 2 * N]
             biases = saved_tensors[2 * N : 3 * N]
-            main_grads = saved_tensors[3 * N :]
+            main_grads = ctx.main_grads
 
             if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:  # TOSO
                 for i in ctx.num_gemms:
@@ -281,7 +282,7 @@ class _GroupedLinear(torch.autograd.Function):
 
             if ctx.weights_requires_grad:
                 if ctx.fuse_wgrad_accumulation:
-                    wgrad_list = [w.main_grad for w in weights]
+                    wgrad_list = main_grads
                 else:
                     wgrad_list = [
                         torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device)
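On the module side, the bug is visible in the hunks above: forward saves only the inputs, FP8 weights, and biases (3 * N tensors), while the old backward expected the main_grad buffers to appear at saved_tensors[3 * N :]; the fix keeps direct references to those buffers on ctx.main_grads instead. A generic sketch of that pattern, not Transformer Engine code and with illustrative names, for a single plain linear layer:

# Generic sketch (illustrative, not the library's backward): keep a reference to
# the caller-owned FP32 accumulation buffer on ctx instead of routing it through
# save_for_backward, accumulate wgrad into it in place, and return None for the
# weight so param.grad stays unset.
import torch

class LinearWithFusedWgrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, weight, main_grad):
        ctx.save_for_backward(inp, weight)  # tensors needed for the math
        ctx.main_grad = main_grad           # plain attribute, mirrors ctx.main_grads in the fix
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight = ctx.saved_tensors
        # In-place accumulation into the externally owned FP32 buffer.
        ctx.main_grad.add_(grad_out.t().to(torch.float32) @ inp.to(torch.float32))
        dgrad = grad_out @ weight
        # None for the weight (and the buffer) keeps weight.grad unset; an
        # optimizer with fused accumulation reads main_grad instead.
        return dgrad, None, None

M, N, K = 8, 16, 32
weight = torch.nn.Parameter(torch.randn(N, K))
weight.main_grad = torch.zeros(N, K, dtype=torch.float32)
inp = torch.randn(M, K, requires_grad=True)

out = LinearWithFusedWgrad.apply(inp, weight, weight.main_grad)
out.sum().backward()
assert weight.grad is None and weight.main_grad.abs().sum() > 0

Because the buffer is only referenced through ctx and mutated in place, backward does not produce a separate weight gradient, which matches the assertion added to the test helper that p.grad stays None when main_grad is present.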