"vscode:/vscode.git/clone" did not exist on "5eb712f96256278a6ffdbebb53d551c2aa2a8e45"
Unverified Commit 64a3d1d5 authored by Sangkug Lym, committed by GitHub

Make user buffer name configurable (#499)



* Make user buffer name configurable
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* Apply suggestions from code review
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Fix duplicate argument
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix autograd
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Sangkug Lym <slym@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c706ff8d
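
This change replaces the hard-coded userbuffer names ("qkv_*" in LayerNormLinear, "proj_*" in Linear) with a caller-supplied ub_name prefix, so layers other than the attention QKV / projection pair can reuse the tensor-parallel userbuffer overlap paths. A minimal usage sketch (hypothetical sizes; assumes the userbuffer communicators have already been registered, e.g. via initialize_ub, under keys matching ub_name + "_fprop"/"_dgrad"/"_wgrad", and that the tensor/sequence-parallel setup these overlap flags require is in place):

    import transformer_engine.pytorch as te

    # QKV-style layer: buffers are looked up as "qkv_fprop", "qkv_dgrad", "qkv_wgrad".
    qkv = te.LayerNormLinear(
        in_features=4096,
        out_features=3 * 4096,
        ub_bulk_wgrad=True,
        ub_bulk_dgrad=True,
        ub_split_ag=True,
        ub_name="qkv",
    )

    # Projection-style layer: buffers are looked up as "proj_fprop" and "proj_dgrad".
    proj = te.Linear(
        in_features=4096,
        out_features=4096,
        ub_split_rs=True,
        ub_name="proj",
    )
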
@@ -86,6 +86,7 @@ class _LayerNormLinear(torch.autograd.Function):
         ub_bulk_dgrad: bool,
         ub_split_ag: bool,
         ub_atomic_gemm_ag: bool,
+        ub_name: str,
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
         # Make sure input dimensions are compatible
         in_features = ln_weight.numel()
@@ -111,7 +112,7 @@ class _LayerNormLinear(torch.autograd.Function):
         if ub_split_ag or ub_atomic_gemm_ag:
             dim_size = list(inputmat.size())
             dim_size[0] = dim_size[0] * tp_world_size
-            ub_obj_lnout = get_ub("qkv_fprop")
+            ub_obj_lnout = get_ub(ub_name+"_fprop")
             ln_out = ub_obj_lnout.get_ubuf_output(0)
         else:
             ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype
@@ -268,6 +269,7 @@ class _LayerNormLinear(torch.autograd.Function):
         ctx.zero_centered_gamma = zero_centered_gamma
         ctx.ub_bulk_wgrad = ub_bulk_wgrad
         ctx.ub_bulk_dgrad = ub_bulk_dgrad
+        ctx.ub_name = ub_name
         ctx.requires_dgrad = inp.requires_grad
         ctx.normalization = normalization
@@ -310,7 +312,7 @@ class _LayerNormLinear(torch.autograd.Function):
         if ctx.ub_bulk_dgrad:
             dim_size = list(ln_out.size())
             dim_size[0] = dim_size[0] * tp_world_size
-            ub_obj_lnout = get_ub("qkv_dgrad")
+            ub_obj_lnout = get_ub(ctx.ub_name+"_dgrad")
             ub_obj_lnout.copy_input_to_ubuf(ln_out, 1)
         (
             grad_output,
@@ -350,7 +352,7 @@ class _LayerNormLinear(torch.autograd.Function):
         dgrad_size = list(grad_output.size())
         dgrad_size[1] = weight.size(1)
         if ctx.ub_bulk_wgrad: # allocate dgrad output
-            ub_obj_dgrad = get_ub("qkv_wgrad")
+            ub_obj_dgrad = get_ub(ctx.ub_name+"_wgrad")
             dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output
         else:
             dgrad = torch.empty (dgrad_size, dtype=ctx.activation_dtype, device=weight.device)
@@ -567,6 +569,7 @@ class _LayerNormLinear(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
@@ -674,6 +677,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
         ub_bulk_dgrad: bool = False,
         ub_split_ag: bool = False,
         ub_atomic_gemm_ag: bool = False,
+        ub_name: Optional[str] = None,
     ) -> None:
         super().__init__()
@@ -694,6 +698,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
         self.ub_bulk_dgrad = ub_bulk_dgrad
         self.ub_split_ag = ub_split_ag
         self.ub_atomic_gemm_ag = ub_atomic_gemm_ag
+        if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_split_ag]):
+            assert ub_name is not None, "Userbuffer name [string] is not set."
+        self.ub_name = ub_name
         if ub_bulk_wgrad or ub_bulk_dgrad or ub_split_ag or ub_atomic_gemm_ag:
             assert (
@@ -978,6 +986,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
                 self.ub_bulk_dgrad,
                 self.ub_split_ag,
                 self.ub_atomic_gemm_ag,
+                self.ub_name,
             )
             out = fwd_fn(*args)
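
In _LayerNormLinear the forward and backward passes now derive their userbuffer keys from ub_name instead of the fixed "qkv" prefix. A minimal sketch of that key derivation (ub_keys is a hypothetical helper written only to illustrate the naming convention; it is not part of the diff):

    def ub_keys(ub_name: str) -> dict:
        # Keys that _LayerNormLinear passes to get_ub() after this change.
        return {
            "fprop": ub_name + "_fprop",   # forward: all-gather of the layernorm output
            "dgrad": ub_name + "_dgrad",   # backward: bulk dgrad overlap buffer
            "wgrad": ub_name + "_wgrad",   # backward: bulk wgrad overlap buffer
        }

    # With ub_name="qkv" this reproduces the previously hard-coded buffer names.
    assert ub_keys("qkv") == {"fprop": "qkv_fprop", "dgrad": "qkv_dgrad", "wgrad": "qkv_wgrad"}
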
@@ -82,6 +82,7 @@ class _Linear(torch.autograd.Function):
         ub_split_ag: bool,
         ub_atomic_gemm_rs: bool,
         ub_atomic_gemm_ag: bool,
+        ub_name: str,
     ) -> torch.Tensor:
         # Make sure input dimensions are compatible
         in_features = weight.shape[-1]
@@ -180,7 +181,7 @@ class _Linear(torch.autograd.Function):
         proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = (
             None, None, None, activation_dtype)
         if ub_split_rs or ub_atomic_gemm_rs:
-            ub_obj_projout = get_ub("proj_fprop")
+            ub_obj_projout = get_ub(ub_name+"_fprop")
             out = ub_obj_projout.get_ubuf_output(1)
             dim_size = list(inputmat_total.size())
             dim_size[0] = dim_size[0] // tp_world_size
@@ -285,6 +286,7 @@ class _Linear(torch.autograd.Function):
         ctx.tp_group = tp_group
         ctx.ub_split_ag = ub_split_ag
         ctx.ub_atomic_gemm_ag = ub_atomic_gemm_ag
+        ctx.ub_name = ub_name
         ctx.tp_size = tp_size
         ctx.requires_dgrad = inp.requires_grad
@@ -326,7 +328,7 @@ class _Linear(torch.autograd.Function):
         if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag:
             dim_size = list(grad_output.size())
             dim_size[0] = dim_size[0] * tp_world_size
-            ctx.ub_obj_gradout = get_ub("proj_dgrad")
+            ctx.ub_obj_gradout = get_ub(ctx.ub_name+"_dgrad")
         (
             grad_output,
             grad_output_c,
@@ -499,6 +501,7 @@ class _Linear(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
@@ -588,6 +591,7 @@ class Linear(TransformerEngineBaseModule):
         ub_split_ag: bool = False,
         ub_atomic_gemm_rs: bool = False,
         ub_atomic_gemm_ag: bool = False,
+        ub_name: Optional[str] = None,
     ) -> None:
         super().__init__()
@@ -604,6 +608,9 @@ class Linear(TransformerEngineBaseModule):
         self.ub_split_ag = ub_split_ag
         self.ub_atomic_gemm_rs = ub_atomic_gemm_rs
         self.ub_atomic_gemm_ag = ub_atomic_gemm_ag
+        if any([ub_atomic_gemm_rs, ub_atomic_gemm_ag]):
+            assert ub_name is not None, "Userbuffer name [string] is not set."
+        self.ub_name = ub_name
         if ub_split_rs or ub_split_ag or ub_atomic_gemm_rs:
             assert (
@@ -848,6 +855,7 @@ class Linear(TransformerEngineBaseModule):
                 self.ub_split_ag,
                 self.ub_atomic_gemm_rs,
                 self.ub_atomic_gemm_ag,
+                self.ub_name,
             )
             out = linear_fn(*args)
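
_Linear follows the same pattern with the "_fprop" and "_dgrad" suffixes, and both module constructors now assert that ub_name is provided whenever a userbuffer path that needs it is enabled. A rough sketch of the intended fail-fast behavior (hypothetical sizes; actually using the overlap flags still requires the usual distributed and userbuffer initialization):

    import transformer_engine.pytorch as te

    try:
        # ub_name is left at its None default, so the new assertion should fire
        # during construction rather than failing later inside get_ub().
        te.Linear(1024, 1024, ub_atomic_gemm_rs=True)
    except AssertionError as err:
        print(err)   # "Userbuffer name [string] is not set."
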