Unverified commit 64a3d1d5 authored by Sangkug Lym, committed by GitHub
Browse files

Make user buffer name configurable (#499)



* Make user buffer name configurable
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* Apply suggestions from code review
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Fix duplicate argument
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix autograd
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Sangkug Lym <slym@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c706ff8d
...@@ -86,6 +86,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -86,6 +86,7 @@ class _LayerNormLinear(torch.autograd.Function):
ub_bulk_dgrad: bool, ub_bulk_dgrad: bool,
ub_split_ag: bool, ub_split_ag: bool,
ub_atomic_gemm_ag: bool, ub_atomic_gemm_ag: bool,
ub_name: str,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = ln_weight.numel() in_features = ln_weight.numel()
...@@ -111,7 +112,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -111,7 +112,7 @@ class _LayerNormLinear(torch.autograd.Function):
if ub_split_ag or ub_atomic_gemm_ag: if ub_split_ag or ub_atomic_gemm_ag:
dim_size = list(inputmat.size()) dim_size = list(inputmat.size())
dim_size[0] = dim_size[0] * tp_world_size dim_size[0] = dim_size[0] * tp_world_size
ub_obj_lnout = get_ub("qkv_fprop") ub_obj_lnout = get_ub(ub_name+"_fprop")
ln_out = ub_obj_lnout.get_ubuf_output(0) ln_out = ub_obj_lnout.get_ubuf_output(0)
else: else:
ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype
...@@ -268,6 +269,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -268,6 +269,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.zero_centered_gamma = zero_centered_gamma ctx.zero_centered_gamma = zero_centered_gamma
ctx.ub_bulk_wgrad = ub_bulk_wgrad ctx.ub_bulk_wgrad = ub_bulk_wgrad
ctx.ub_bulk_dgrad = ub_bulk_dgrad ctx.ub_bulk_dgrad = ub_bulk_dgrad
ctx.ub_name = ub_name
ctx.requires_dgrad = inp.requires_grad ctx.requires_dgrad = inp.requires_grad
ctx.normalization = normalization ctx.normalization = normalization
...@@ -310,7 +312,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -310,7 +312,7 @@ class _LayerNormLinear(torch.autograd.Function):
if ctx.ub_bulk_dgrad: if ctx.ub_bulk_dgrad:
dim_size = list(ln_out.size()) dim_size = list(ln_out.size())
dim_size[0] = dim_size[0] * tp_world_size dim_size[0] = dim_size[0] * tp_world_size
ub_obj_lnout = get_ub("qkv_dgrad") ub_obj_lnout = get_ub(ctx.ub_name+"_dgrad")
ub_obj_lnout.copy_input_to_ubuf(ln_out, 1) ub_obj_lnout.copy_input_to_ubuf(ln_out, 1)
( (
grad_output, grad_output,
...@@ -350,7 +352,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -350,7 +352,7 @@ class _LayerNormLinear(torch.autograd.Function):
dgrad_size = list(grad_output.size()) dgrad_size = list(grad_output.size())
dgrad_size[1] = weight.size(1) dgrad_size[1] = weight.size(1)
if ctx.ub_bulk_wgrad: # allocate dgrad output if ctx.ub_bulk_wgrad: # allocate dgrad output
ub_obj_dgrad = get_ub("qkv_wgrad") ub_obj_dgrad = get_ub(ctx.ub_name+"_wgrad")
dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output
else: else:
dgrad = torch.empty (dgrad_size, dtype=ctx.activation_dtype, device=weight.device) dgrad = torch.empty (dgrad_size, dtype=ctx.activation_dtype, device=weight.device)
...@@ -567,6 +569,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -567,6 +569,7 @@ class _LayerNormLinear(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
) )
...@@ -674,6 +677,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -674,6 +677,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
ub_bulk_dgrad: bool = False, ub_bulk_dgrad: bool = False,
ub_split_ag: bool = False, ub_split_ag: bool = False,
ub_atomic_gemm_ag: bool = False, ub_atomic_gemm_ag: bool = False,
ub_name: Optional[str] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -694,6 +698,10 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -694,6 +698,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.ub_bulk_dgrad = ub_bulk_dgrad self.ub_bulk_dgrad = ub_bulk_dgrad
self.ub_split_ag = ub_split_ag self.ub_split_ag = ub_split_ag
self.ub_atomic_gemm_ag = ub_atomic_gemm_ag self.ub_atomic_gemm_ag = ub_atomic_gemm_ag
if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_split_ag]):
assert ub_name is not None, "Userbuffer name [string] is not set."
self.ub_name = ub_name
if ub_bulk_wgrad or ub_bulk_dgrad or ub_split_ag or ub_atomic_gemm_ag: if ub_bulk_wgrad or ub_bulk_dgrad or ub_split_ag or ub_atomic_gemm_ag:
assert ( assert (
...@@ -978,6 +986,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -978,6 +986,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.ub_bulk_dgrad, self.ub_bulk_dgrad,
self.ub_split_ag, self.ub_split_ag,
self.ub_atomic_gemm_ag, self.ub_atomic_gemm_ag,
self.ub_name,
) )
out = fwd_fn(*args) out = fwd_fn(*args)
......
...@@ -82,6 +82,7 @@ class _Linear(torch.autograd.Function): ...@@ -82,6 +82,7 @@ class _Linear(torch.autograd.Function):
ub_split_ag: bool, ub_split_ag: bool,
ub_atomic_gemm_rs: bool, ub_atomic_gemm_rs: bool,
ub_atomic_gemm_ag: bool, ub_atomic_gemm_ag: bool,
ub_name: str,
) -> torch.Tensor: ) -> torch.Tensor:
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = weight.shape[-1] in_features = weight.shape[-1]
...@@ -180,7 +181,7 @@ class _Linear(torch.autograd.Function): ...@@ -180,7 +181,7 @@ class _Linear(torch.autograd.Function):
proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = ( proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = (
None, None, None, activation_dtype) None, None, None, activation_dtype)
if ub_split_rs or ub_atomic_gemm_rs: if ub_split_rs or ub_atomic_gemm_rs:
ub_obj_projout = get_ub("proj_fprop") ub_obj_projout = get_ub(ub_name+"_fprop")
out = ub_obj_projout.get_ubuf_output(1) out = ub_obj_projout.get_ubuf_output(1)
dim_size = list(inputmat_total.size()) dim_size = list(inputmat_total.size())
dim_size[0] = dim_size[0] // tp_world_size dim_size[0] = dim_size[0] // tp_world_size
...@@ -285,6 +286,7 @@ class _Linear(torch.autograd.Function): ...@@ -285,6 +286,7 @@ class _Linear(torch.autograd.Function):
ctx.tp_group = tp_group ctx.tp_group = tp_group
ctx.ub_split_ag = ub_split_ag ctx.ub_split_ag = ub_split_ag
ctx.ub_atomic_gemm_ag = ub_atomic_gemm_ag ctx.ub_atomic_gemm_ag = ub_atomic_gemm_ag
ctx.ub_name = ub_name
ctx.tp_size = tp_size ctx.tp_size = tp_size
ctx.requires_dgrad = inp.requires_grad ctx.requires_dgrad = inp.requires_grad
...@@ -326,7 +328,7 @@ class _Linear(torch.autograd.Function): ...@@ -326,7 +328,7 @@ class _Linear(torch.autograd.Function):
if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag: if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag:
dim_size = list(grad_output.size()) dim_size = list(grad_output.size())
dim_size[0] = dim_size[0] * tp_world_size dim_size[0] = dim_size[0] * tp_world_size
ctx.ub_obj_gradout = get_ub("proj_dgrad") ctx.ub_obj_gradout = get_ub(ctx.ub_name+"_dgrad")
( (
grad_output, grad_output,
grad_output_c, grad_output_c,
...@@ -499,6 +501,7 @@ class _Linear(torch.autograd.Function): ...@@ -499,6 +501,7 @@ class _Linear(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
) )
...@@ -588,6 +591,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -588,6 +591,7 @@ class Linear(TransformerEngineBaseModule):
ub_split_ag: bool = False, ub_split_ag: bool = False,
ub_atomic_gemm_rs: bool = False, ub_atomic_gemm_rs: bool = False,
ub_atomic_gemm_ag: bool = False, ub_atomic_gemm_ag: bool = False,
ub_name: Optional[str] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -604,6 +608,9 @@ class Linear(TransformerEngineBaseModule): ...@@ -604,6 +608,9 @@ class Linear(TransformerEngineBaseModule):
self.ub_split_ag = ub_split_ag self.ub_split_ag = ub_split_ag
self.ub_atomic_gemm_rs = ub_atomic_gemm_rs self.ub_atomic_gemm_rs = ub_atomic_gemm_rs
self.ub_atomic_gemm_ag = ub_atomic_gemm_ag self.ub_atomic_gemm_ag = ub_atomic_gemm_ag
if any([ub_atomic_gemm_rs, ub_atomic_gemm_ag]):
assert ub_name is not None, "Userbuffer name [string] is not set."
self.ub_name = ub_name
if ub_split_rs or ub_split_ag or ub_atomic_gemm_rs: if ub_split_rs or ub_split_ag or ub_atomic_gemm_rs:
assert ( assert (
...@@ -848,6 +855,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -848,6 +855,7 @@ class Linear(TransformerEngineBaseModule):
self.ub_split_ag, self.ub_split_ag,
self.ub_atomic_gemm_rs, self.ub_atomic_gemm_rs,
self.ub_atomic_gemm_ag, self.ub_atomic_gemm_ag,
self.ub_name,
) )
out = linear_fn(*args) out = linear_fn(*args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.