Unverified commit fceff07a, authored by Xin Yao, committed by GitHub
Browse files

[PyTorch] Fix fuse_wgrad_accumulation for GroupedLinear (#1488)



* fix fuse_wgrad_accumulation for GroupedLinear
Signed-off-by: Xin Yao <xiny@nvidia.com>

* fix fuse_wgrad_accumulation for GroupedLinear
Signed-off-by: Xin Yao <xiny@nvidia.com>

* update tests
Signed-off-by: Xin Yao <xiny@nvidia.com>

---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 56c0c070
......@@ -1400,7 +1400,9 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization):
assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])
def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False):
def _test_grouped_linear_accuracy(
block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
):
reset_rng_states()
if fp8:
FP8GlobalStateManager.reset()
......@@ -1447,7 +1449,11 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, f
outputs = [out, inp_hidden_states.grad]
for p in block.parameters():
if p.requires_grad:
outputs.append(p.grad)
if getattr(p, "main_grad", None) is not None:
outputs.append(p.main_grad)
assert p.grad is None # grad should be None if fuse_wgrad_accumulation is True
else:
outputs.append(p.grad)
return outputs
......@@ -1458,8 +1464,17 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, f
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_grouped_linear_accuracy(
dtype, num_gemms, bs, model, fp8, recipe, fp8_model_params, parallel_mode=None
dtype,
num_gemms,
bs,
model,
fp8,
recipe,
fp8_model_params,
fuse_wgrad_accumulation,
parallel_mode=None,
):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
......@@ -1481,6 +1496,7 @@ def test_grouped_linear_accuracy(
params_dtype=dtype,
parallel_mode=parallel_mode,
device="cuda",
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
).eval()
sequential_linear = torch.nn.ModuleList(
[
......@@ -1491,6 +1507,7 @@ def test_grouped_linear_accuracy(
params_dtype=dtype,
parallel_mode=parallel_mode,
device="cuda",
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
).eval()
for _ in range(num_gemms)
]
......@@ -1501,12 +1518,16 @@ def test_grouped_linear_accuracy(
for i in range(num_gemms):
sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone())
sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone())
if fuse_wgrad_accumulation:
weight_i = getattr(grouped_linear, f"weight{i}")
weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
sequential_linear[i].weight.main_grad = weight_i.main_grad.clone()
outputs_ref = _test_grouped_linear_accuracy(
sequential_linear, num_gemms, bs, dtype, config, recipe, fp8
sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
)
outputs = _test_grouped_linear_accuracy(
grouped_linear, num_gemms, bs, dtype, config, recipe, fp8
grouped_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
)
# Should be a bit-wise match
......@@ -1527,6 +1548,7 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode, recipe):
recipe=recipe,
fp8_model_params=True,
parallel_mode=parallel_mode,
fuse_wgrad_accumulation=True,
)
......@@ -1541,6 +1563,7 @@ def test_grouped_linear_accuracy_single_gemm(recipe):
fp8=True,
recipe=recipe,
fp8_model_params=True,
fuse_wgrad_accumulation=True,
)
......
......@@ -178,7 +178,6 @@ class _GroupedLinear(torch.autograd.Function):
if is_grad_enabled:
saved_inputs, saved_weights = [], []
ctx.weights_shape_1 = weights[0].shape[1]
tensors_to_save, tensor_objects = prepare_for_saving(*inputmats, *weights_fp8, *biases)
......@@ -186,9 +185,11 @@ class _GroupedLinear(torch.autograd.Function):
ctx.tensor_objects = tensor_objects
ctx.weights_requires_grad = weights[0].requires_grad
if fuse_wgrad_accumulation and ctx.weights_requires_grad:
ctx.main_grads = [weights[i].main_grad for i in range(num_gemms)]
else:
ctx.main_grads = [None] * num_gemms
ctx.device = device
ctx.saved_inputs = saved_inputs
ctx.saved_weights = saved_weights
ctx.grad_output_quantizers = grad_output_quantizers
ctx.m_splits = m_splits
ctx.num_gemms = num_gemms
......@@ -220,7 +221,7 @@ class _GroupedLinear(torch.autograd.Function):
inputmats = saved_tensors[:N]
weights = saved_tensors[N : 2 * N]
biases = saved_tensors[2 * N : 3 * N]
main_grads = saved_tensors[3 * N :]
main_grads = ctx.main_grads
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: # TOSO
for i in ctx.num_gemms:
......@@ -281,31 +282,31 @@ class _GroupedLinear(torch.autograd.Function):
if ctx.weights_requires_grad:
if ctx.fuse_wgrad_accumulation:
wgrad_list = [w.main_grad for w in weights]
wgrad_list = main_grads
else:
wgrad_list = [
torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device)
for w in weights
]
# WGRAD
_, grad_biases_, _ = general_grouped_gemm(
inputmats,
grad_output,
wgrad_list,
ctx.activation_dtype,
get_multi_stream_cublas_workspace(),
layout="NT",
grad=True,
m_splits=ctx.m_splits,
use_bias=ctx.use_bias if grad_biases[0] is None else None,
bias=biases,
use_split_accumulator=_2X_ACC_WGRAD,
accumulate=accumulate_wgrad_into_param_main_grad,
)
for i in range(ctx.num_gemms):
if grad_biases[i] is None:
grad_biases[i] = grad_biases_[i]
del grad_biases_
# WGRAD
_, grad_biases_, _ = general_grouped_gemm(
inputmats,
grad_output,
wgrad_list,
ctx.activation_dtype,
get_multi_stream_cublas_workspace(),
layout="NT",
grad=True,
m_splits=ctx.m_splits,
use_bias=ctx.use_bias if grad_biases[0] is None else None,
bias=biases,
use_split_accumulator=_2X_ACC_WGRAD,
accumulate=accumulate_wgrad_into_param_main_grad,
)
for i in range(ctx.num_gemms):
if grad_biases[i] is None:
grad_biases[i] = grad_biases_[i]
del grad_biases_
# Deallocate input tensor
clear_tensor_data(*inputmats)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment