Unverified Commit fb749619 authored by Przemyslaw Tredak, committed by GitHub

Removed the unused options from GroupedLinear docs and fixed the bug with offsets (#1220)



* Removing the unused options from GroupedLinear docs and fixing the bug
with offsets
Signed-off-by: Przemyslaw Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* offsets -> fp8_meta_offsets
Signed-off-by: Przemyslaw Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Przemyslaw Tredak <ptredak@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 46075b98
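For context on the "bug with offsets": before this patch the fp8_meta index offsets used by _GroupedLinear lived in module-level globals that each GroupedLinear instance rewrote in its constructor, so any later instance (or one with a different num_gemms) would overwrite the offsets seen by earlier ones. The patch moves them into a per-instance dictionary passed to the autograd function as fp8_meta_offsets. A minimal standalone sketch of the difference (the helper names below are hypothetical; only the offset values mirror the patch):

# Hypothetical sketch, not code from the patch: why shared module-level offsets
# break with multiple GroupedLinear instances.
_GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT = 0, 0, 0  # shared module-level state (old scheme)

def init_grouped_linear_old(num_gemms):
    # Old behaviour: every construction rewrites the shared globals.
    global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
    _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT = 0, num_gemms, 2 * num_gemms

def init_grouped_linear_new(num_gemms):
    # New behaviour: each module keeps its own offset table.
    return {"input": 0, "weight": num_gemms, "output": 2 * num_gemms, "grad_output": 0}

init_grouped_linear_old(num_gemms=4)   # first module expects weight offset 4
init_grouped_linear_old(num_gemms=8)   # second module silently changes it to 8
print(_GEMM_WEIGHT)                    # 8 -> the first module now indexes the wrong slots

offsets_a = init_grouped_linear_new(4)
offsets_b = init_grouped_linear_new(8)
print(offsets_a["weight"], offsets_b["weight"])  # 4 8 -> independent per module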
@@ -44,18 +44,6 @@ from ..export import is_in_onnx_export_mode

 __all__ = ["GroupedLinear"]

-"""
-The offset for fp8_meta_index.
-_GEMM_INPUT = 0
-_GEMM_WEIGHT = num_gemms
-_GEMM_OUTPUT = 2 * num_gemms
-Must be properly set in GroupedLinear's initialization.
-"""
-_GEMM_INPUT = 0
-_GEMM_WEIGHT = 0
-_GEMM_OUTPUT = 0
-_GRAD_OUTPUT = 0
-

 class _GroupedLinear(torch.autograd.Function):
     """GroupedLinear semi-top level module
@@ -74,12 +62,9 @@ class _GroupedLinear(torch.autograd.Function):
         fp8_meta: Dict[str, Any],
         fuse_wgrad_accumulation: bool,
         cpu_offloading: bool,
-        tp_group: Union[dist_group_type, None],
-        tp_size: int,
         sequence_parallel: bool,
-        tensor_parallel: bool,
         activation_dtype: torch.dtype,
-        parallel_mode: Union[str, None],
+        fp8_meta_offsets: Dict[str, int],
         is_grad_enabled: bool,
         weights_fp8: List[Union[Float8Tensor, None]],
         *weights_and_biases: Union[Float8Tensor, torch.Tensor, None],
@@ -103,7 +88,6 @@ class _GroupedLinear(torch.autograd.Function):
         inputmats_t = []
         inputmat_scale_inv = None

-        global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
         if fp8:
             fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
             inputmat_scale_inv = torch.empty([num_gemms], dtype=torch.float32, device=inp.device)
@@ -114,7 +98,9 @@ class _GroupedLinear(torch.autograd.Function):
                 and not sequence_parallel
             ):
                 # FP8 input for forward, FP8 input transpose for backward wgrad
-                indices = list(range(_GEMM_INPUT, _GEMM_INPUT + num_gemms))
+                indices = list(
+                    range(fp8_meta_offsets["input"], fp8_meta_offsets["input"] + num_gemms)
+                )
                 inputmats, inputmats_t = fp8_multi_cast_transpose_fused(
                     inputmats_no_fp8,
                     fp8_meta["scaling_fwd"],
@@ -130,7 +116,7 @@ class _GroupedLinear(torch.autograd.Function):
                     cast_to_fp8(
                         inputmats_no_fp8[i],
                         fp8_meta["scaling_fwd"],
-                        _GEMM_INPUT + i,
+                        fp8_meta_offsets["input"] + i,
                         fp8_dtype_forward,
                         scale_inv=inputmat_scale_inv,
                     )
@@ -194,14 +180,14 @@ class _GroupedLinear(torch.autograd.Function):
             for i in range(num_gemms):
                 # amax of input
                 amin, amax = inputmats[i].aminmax()
-                fp8_meta["scaling_fwd"].amax_history[0][_GEMM_INPUT + i] = torch.max(
-                    -amin, amax
-                ).float()
+                fp8_meta["scaling_fwd"].amax_history[0][fp8_meta_offsets["input"] + i] = (
+                    torch.max(-amin, amax).float()
+                )
                 # amax of weight
                 amin, amax = weights[i].aminmax()
-                fp8_meta["scaling_fwd"].amax_history[0][_GEMM_WEIGHT + i] = torch.max(
-                    -amin, amax
-                ).float()
+                fp8_meta["scaling_fwd"].amax_history[0][fp8_meta_offsets["weight"] + i] = (
+                    torch.max(-amin, amax).float()
+                )

         out = torch.empty(
             [sum(m_splits), weights[0].size(0)],
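The reformatted amax update above keeps the same computation: aminmax() returns the minimum and maximum in one pass, and max(-amin, amax) is the absolute maximum that gets written into the amax history used for FP8 scaling. A tiny standalone check (illustration only, not code from the patch):

# Illustration only: max(-amin, amax) from aminmax() equals the absolute maximum.
import torch

x = torch.tensor([-3.5, 0.25, 2.0])
amin, amax = x.aminmax()
assert torch.max(-amin, amax).item() == x.abs().max().item() == 3.5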
@@ -266,11 +252,8 @@ class _GroupedLinear(torch.autograd.Function):
             ctx.is_first_microbatch = is_first_microbatch
             ctx.use_bias = use_bias
             ctx.sequence_parallel = sequence_parallel
-            ctx.tensor_parallel = tensor_parallel
             ctx.inp_shape = inp.shape
-            ctx.parallel_mode = parallel_mode
-            ctx.tp_group = tp_group
-            ctx.tp_size = tp_size
+            ctx.fp8_meta_offsets = fp8_meta_offsets
             ctx.requires_dgrad = inp.requires_grad
             ctx.reduce_and_update_bwd_fp8_tensors = False
             if ctx.fp8 and requires_grad(inp, weights[0], biases[0]):
@@ -300,7 +283,6 @@ class _GroupedLinear(torch.autograd.Function):
                     w.main_grad = main_grads[i]
                     weights[i] = w

-            global _GEMM_INPUT, _GEMM_WEIGHT, _GRAD_OUTPUT
             # preprocess grad_output
             grad_output = grad_output.contiguous()
             grad_output_mats = torch.split(
@@ -318,13 +300,18 @@ class _GroupedLinear(torch.autograd.Function):
                         fp8_cast_transpose_bgrad_fused(
                             grad_output_mats[i],
                             ctx.fp8_meta["scaling_bwd"],
-                            _GRAD_OUTPUT + i,
+                            ctx.fp8_meta_offsets["grad_output"] + i,
                             fp8_dtype_backward,
                         )
                     )
             else:
                 if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
-                    indices = list(range(_GRAD_OUTPUT, _GRAD_OUTPUT + ctx.num_gemms))
+                    indices = list(
+                        range(
+                            ctx.fp8_meta_offsets["grad_output"],
+                            ctx.fp8_meta_offsets["grad_output"] + ctx.num_gemms,
+                        )
+                    )
                     grad_output_c, grad_output_t = fp8_multi_cast_transpose_fused(
                         grad_output_mats,
                         ctx.fp8_meta["scaling_bwd"],
@@ -338,7 +325,7 @@ class _GroupedLinear(torch.autograd.Function):
                         grad_output_c[i] = cast_to_fp8(
                             grad_output_mats[i],
                             ctx.fp8_meta["scaling_bwd"],
-                            _GRAD_OUTPUT + i,
+                            ctx.fp8_meta_offsets["grad_output"] + i,
                             fp8_dtype_backward,
                         )
@@ -363,7 +350,7 @@ class _GroupedLinear(torch.autograd.Function):
                     weights_fp8[0]._fp8_dtype,
                     grad_output_c,
                     ctx.fp8_meta["scaling_bwd"].scale_inv,
-                    _GRAD_OUTPUT,
+                    ctx.fp8_meta_offsets["grad_output"],
                     fp8_dtype_backward,
                     [dgrad],
                     ctx.activation_dtype,
@@ -416,7 +403,7 @@ class _GroupedLinear(torch.autograd.Function):
                         fp8_dtype_forward,
                         grad_output_t,
                         ctx.fp8_meta["scaling_bwd"].scale_inv,
-                        _GRAD_OUTPUT,
+                        ctx.fp8_meta_offsets["grad_output"],
                         fp8_dtype_backward,
                         wgrad_list,
                         ctx.activation_dtype,
@@ -497,12 +484,9 @@ class _GroupedLinear(torch.autograd.Function):
             None,  # fp8_meta
             None,  # fuse_wgrad_accumulation
             None,  # cpu_offloading
-            None,  # tp_group
-            None,  # tp_size
             None,  # sequence_parallel
-            None,  # tensor_parallel
             None,  # activation_dtype
-            None,  # parallel_mode
+            None,  # fp8_meta_offsets
             None,  # is_grad_enabled
             None,  # weights_fp8
             *wgrad_list,
@@ -536,23 +520,6 @@ class GroupedLinear(TransformerEngineBaseModule):
              responsibility to ensure all parameters are moved to the GPU before running the
              forward pass.

-    Parallelism parameters
-    ----------------------
-    sequence_parallel : bool, default = `False`
-                       if set to `True`, uses sequence parallelism.
-    tp_group : ProcessGroup, default = `None`
-              tensor parallel process group.
-    tp_size : int, default = 1
-             used as TP (tensor parallel) world size when TP groups are not formed during
-             initialization. In this case, users must call the
-             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
-             forward pass to supply the tensor parallel group needed for tensor and sequence
-             parallel collectives.
-    parallel_mode : {None, 'column', 'row'}, default = `None`
-                   used to decide whether this GroupedLinear layer is Column Parallel Linear or Row
-                   Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
-                   When set to `None`, no communication is performed.
-
     Optimization parameters
     -----------------------
     fuse_wgrad_accumulation : bool, default = 'False'
@@ -613,8 +580,7 @@ class GroupedLinear(TransformerEngineBaseModule):
         self.get_rng_state_tracker = get_rng_state_tracker
         self.rng_tracker_name = rng_tracker_name

-        global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
-        _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT = 0, num_gemms, 2 * num_gemms
+        self._offsets = {"input": 0, "weight": num_gemms, "output": 2 * num_gemms, "grad_output": 0}

         if tp_group is None:
             self.tp_size = tp_size
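The per-instance offset table above replaces the old globals. A short standalone sketch (hypothetical code, not from the patch) of how those offsets map GEMM i to fp8_meta slots, matching the indexing used in the forward and backward hunks:

# Hypothetical sketch: GEMM i of one GroupedLinear instance uses slot
# offsets["input"] + i and offsets["weight"] + i in the forward scaling tensors,
# and offsets["grad_output"] + i in the backward scaling tensors.
num_gemms = 4
offsets = {"input": 0, "weight": num_gemms, "output": 2 * num_gemms, "grad_output": 0}

for i in range(num_gemms):
    input_idx = offsets["input"] + i        # 0..3 in scaling_fwd
    weight_idx = offsets["weight"] + i      # 4..7 in scaling_fwd
    grad_idx = offsets["grad_output"] + i   # 0..3 in scaling_bwd
    print(f"gemm {i}: input slot {input_idx}, weight slot {weight_idx}, grad_output slot {grad_idx}")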
@@ -651,7 +617,7 @@ class GroupedLinear(TransformerEngineBaseModule):
                 ),
                 init_fn=init_method,
                 get_rng_state_tracker=get_rng_state_tracker,
-                fp8_meta_index=_GEMM_WEIGHT + i,
+                fp8_meta_index=self._offsets["weight"] + i,
             )

         # Construct bias parameters if needed
@@ -774,7 +740,7 @@ class GroupedLinear(TransformerEngineBaseModule):
                 weight_tensors_fp8[i] = self.get_fp8_workspace(
                     tensor=weight_tensors[i],
                     fp8_meta_forward=True,
-                    fp8_meta_index=_GEMM_WEIGHT + i,
+                    fp8_meta_index=self._offsets["weight"] + i,
                     cache_name=(None if is_first_microbatch is None else f"weight{i}"),
                     update_workspace=update_workspace,
                     skip_update_flag=skip_fp8_weight_update,
@@ -798,12 +764,9 @@ class GroupedLinear(TransformerEngineBaseModule):
             self.fp8_meta,
             self.fuse_wgrad_accumulation,
             CPUOffloadEnabled,
-            self.tp_group,
-            self.tp_size,
             self.sequence_parallel,
-            self.tp_size > 1,
             self.activation_dtype,
-            self.parallel_mode,
+            self._offsets,
             torch.is_grad_enabled(),
             weight_tensors_fp8,
             *weight_tensors,
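None of these changes touch the public module interface; only the removed (and already unused) parallelism options disappear from the docstring. A hedged usage sketch, assuming the documented GroupedLinear constructor (num_gemms, in_features, out_features) and forward(inp, m_splits) and a CUDA device, to show what stays the same for callers:

# Hedged usage sketch (assumed public API, not verified against this exact revision).
import torch
import transformer_engine.pytorch as te

layer = te.GroupedLinear(num_gemms=2, in_features=64, out_features=64)
x = torch.randn(32, 64, device="cuda")
out = layer(x, [16, 16])  # rows 0..15 go through GEMM 0, rows 16..31 through GEMM 1
print(out.shape)          # torch.Size([32, 64])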