Unverified Commit 6e90fcb7 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Upgrade pylint to 3.3.1 (#1257)



* Upgrade pylint and first round formatting
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* round 2
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* round 3
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Format and fixes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Paddle lint
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Reviews
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* FIxes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* More linting
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Run formatter
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Paddle lint
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 161b1d98
...@@ -38,6 +38,7 @@ class _LayerNorm(torch.autograd.Function): ...@@ -38,6 +38,7 @@ class _LayerNorm(torch.autograd.Function):
is_grad_enabled: bool, is_grad_enabled: bool,
activation_dtype: torch.dtype, activation_dtype: torch.dtype,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = ln_weight.numel() in_features = ln_weight.numel()
assert inp.is_cuda, "TransformerEngine needs CUDA." assert inp.is_cuda, "TransformerEngine needs CUDA."
...@@ -69,6 +70,7 @@ class _LayerNorm(torch.autograd.Function): ...@@ -69,6 +70,7 @@ class _LayerNorm(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
inputmat, ln_weight, mu, rsigma = ctx.saved_tensors inputmat, ln_weight, mu, rsigma = ctx.saved_tensors
grad_output = grad_output.contiguous() grad_output = grad_output.contiguous()
d_ln_out = grad_output.view(inputmat.shape) d_ln_out = grad_output.view(inputmat.shape)
...@@ -144,7 +146,7 @@ class LayerNorm(torch.nn.Module): ...@@ -144,7 +146,7 @@ class LayerNorm(torch.nn.Module):
self.sequence_parallel = sequence_parallel self.sequence_parallel = sequence_parallel
self.activation_dtype: Optional[torch.dtype] = None self.activation_dtype: Optional[torch.dtype] = None
self.reset_parameters(defer_init=(device == "meta")) self.reset_parameters(defer_init=device == "meta")
# These many SMs are subtracted from the total SM count when calling forward # These many SMs are subtracted from the total SM count when calling forward
# and backward LayerNorm C APIs. These envvars can be used to prevent the LN # and backward LayerNorm C APIs. These envvars can be used to prevent the LN
...@@ -185,7 +187,7 @@ class LayerNorm(torch.nn.Module): ...@@ -185,7 +187,7 @@ class LayerNorm(torch.nn.Module):
@no_torch_dynamo() @no_torch_dynamo()
def forward(self, inp: torch.Tensor) -> torch.Tensor: def forward(self, inp: torch.Tensor) -> torch.Tensor:
"""LayerNorm FWD""" # pylint: disable=missing-function-docstring
# Set the activation type for AMP. # Set the activation type for AMP.
# Note: This will soon be deprecated with # Note: This will soon be deprecated with
......
...@@ -94,6 +94,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -94,6 +94,7 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_output: bool, fp8_output: bool,
fsdp_group: Union[dist_group_type, None], fsdp_group: Union[dist_group_type, None],
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# pylint: disable=missing-function-docstring
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
out_features, in_features = weight.shape out_features, in_features = weight.shape
inp_shape = inp.shape inp_shape = inp.shape
...@@ -153,6 +154,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -153,6 +154,7 @@ class _LayerNormLinear(torch.autograd.Function):
# Column Parallel Linear # Column Parallel Linear
ln_out_gathered = False ln_out_gathered = False
ub_algo = None
if ub_overlap_ag: if ub_overlap_ag:
ln_out_total = ub_obj_lnout.get_ubuf_output(1) ln_out_total = ub_obj_lnout.get_ubuf_output(1)
if not return_layernorm_output: if not return_layernorm_output:
...@@ -385,6 +387,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -385,6 +387,7 @@ class _LayerNormLinear(torch.autograd.Function):
def backward( def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...] ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]: ) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
if isinstance(grad_outputs[0], Float8Tensor): if isinstance(grad_outputs[0], Float8Tensor):
ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1] = grad_outputs[ ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1] = grad_outputs[
0 0
...@@ -479,6 +482,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -479,6 +482,7 @@ class _LayerNormLinear(torch.autograd.Function):
else: else:
dgrad = torch.empty(dgrad_size, dtype=ctx.activation_dtype, device=weight.device) dgrad = torch.empty(dgrad_size, dtype=ctx.activation_dtype, device=weight.device)
rs_out = None
if ctx.ub_bulk_dgrad: if ctx.ub_bulk_dgrad:
ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG
ub_obj = ub_obj_lnout ub_obj = ub_obj_lnout
...@@ -576,6 +580,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -576,6 +580,7 @@ class _LayerNormLinear(torch.autograd.Function):
elif ctx.parallel_mode == "column" and ctx.tensor_parallel: elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True) dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)
wgrad = None
if weight.requires_grad: if weight.requires_grad:
if ctx.fp8: if ctx.fp8:
# WGRAD # WGRAD
...@@ -678,6 +683,8 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -678,6 +683,8 @@ class _LayerNormLinear(torch.autograd.Function):
if ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered: if ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered:
dgrad = dgrad + grad_outputs[1].view_as(dgrad) dgrad = dgrad + grad_outputs[1].view_as(dgrad)
dgamma = None
dbeta = None
if ctx.normalization == "LayerNorm": if ctx.normalization == "LayerNorm":
dgrad, dgamma, dbeta = tex.layernorm_bwd( dgrad, dgamma, dbeta = tex.layernorm_bwd(
dgrad, dgrad,
...@@ -1057,7 +1064,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1057,7 +1064,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
if with_fp8_params: if with_fp8_params:
self.init_fp8_metadata() self.init_fp8_metadata()
self.reset_parameters(defer_init=(device == "meta")) self.reset_parameters(defer_init=device == "meta")
# For RPL, bias has to be added after TP collectives # For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM # So it cannot be fused with the GEMM
......
...@@ -123,6 +123,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -123,6 +123,7 @@ class _LayerNormMLP(torch.autograd.Function):
gemm_gelu_fusion: bool, gemm_gelu_fusion: bool,
fsdp_group: Union[dist_group_type, None], fsdp_group: Union[dist_group_type, None],
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# pylint: disable=missing-function-docstring
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = ln_weight.numel() in_features = ln_weight.numel()
inp_shape = inp.shape inp_shape = inp.shape
...@@ -173,6 +174,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -173,6 +174,7 @@ class _LayerNormMLP(torch.autograd.Function):
# Column Parallel Linear # Column Parallel Linear
ln_out_gathered = False ln_out_gathered = False
ub_algo_ag = None
if ub_overlap_ag: if ub_overlap_ag:
ln_out_total = ub_obj_lnout.get_ubuf_output(1) ln_out_total = ub_obj_lnout.get_ubuf_output(1)
ln_out = torch.empty_like(ln_out) ln_out = torch.empty_like(ln_out)
...@@ -241,23 +243,23 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -241,23 +243,23 @@ class _LayerNormMLP(torch.autograd.Function):
activation_dtype, activation_dtype,
get_workspace(), get_workspace(),
] ]
fp8_gemm_kwargs = dict( fp8_gemm_kwargs = {
bias=fc1_bias, "bias": fc1_bias,
use_bias=use_fc1_bias, "use_bias": use_fc1_bias,
use_split_accumulator=_2X_ACC_FPROP, "use_split_accumulator": _2X_ACC_FPROP,
ub_algo=ub_algo_ag if ub_overlap_ag else None, "ub_algo": ub_algo_ag if ub_overlap_ag else None,
ub=ub_obj_lnout if ub_overlap_ag else None, "ub": ub_obj_lnout if ub_overlap_ag else None,
extra_output_tensor=ln_out if ub_overlap_ag else None, "extra_output_tensor": ln_out if ub_overlap_ag else None,
) }
if gemm_gelu_fusion: if gemm_gelu_fusion:
fp8_gemm_args[8] = torch.uint8 # out_dtype fp8_gemm_args[8] = torch.uint8 # out_dtype
fp8_gemm_kwargs.update( fp8_gemm_kwargs.update(
dict( {
gelu=True, "gelu": True,
out_index=tex.FP8FwdTensors.GEMM2_INPUT, "out_index": tex.FP8FwdTensors.GEMM2_INPUT,
fp8_meta_tensor=fp8_meta["scaling_fwd"], "fp8_meta_tensor": fp8_meta["scaling_fwd"],
D_dtype=fp8_dtype_forward, "D_dtype": fp8_dtype_forward,
) }
) )
fp8_gemm_out = tex.fp8_gemm(*fp8_gemm_args, **fp8_gemm_kwargs) fp8_gemm_out = tex.fp8_gemm(*fp8_gemm_args, **fp8_gemm_kwargs)
if not is_grad_enabled: if not is_grad_enabled:
...@@ -283,6 +285,9 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -283,6 +285,9 @@ class _LayerNormMLP(torch.autograd.Function):
None, None,
activation_dtype, activation_dtype,
) )
rs_out = None
ub_algo_rs = None
if ub_overlap_rs: if ub_overlap_rs:
ub_obj_fc2out = get_ub("fc2_fprop") ub_obj_fc2out = get_ub("fc2_fprop")
fc2_out = ub_obj_fc2out.get_ubuf_output(1) fc2_out = ub_obj_fc2out.get_ubuf_output(1)
...@@ -536,6 +541,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -536,6 +541,7 @@ class _LayerNormMLP(torch.autograd.Function):
def backward( def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...] ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]: ) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
with torch.cuda.nvtx.range("_LayerNormMLP_backward"): with torch.cuda.nvtx.range("_LayerNormMLP_backward"):
( (
inputmat, inputmat,
...@@ -599,6 +605,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -599,6 +605,7 @@ class _LayerNormMLP(torch.autograd.Function):
if tp_world_size == 1: if tp_world_size == 1:
ctx.ub_overlap_ag = False ctx.ub_overlap_ag = False
ub_algo = None
if ctx.ub_overlap_ag: if ctx.ub_overlap_ag:
dim_size = list(grad_outputs[0].size()) dim_size = list(grad_outputs[0].size())
dim_size[0] = dim_size[0] * tp_world_size dim_size[0] = dim_size[0] * tp_world_size
...@@ -640,6 +647,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -640,6 +647,7 @@ class _LayerNormMLP(torch.autograd.Function):
else: else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
fc2_wgrad = None
if ctx.fp8: if ctx.fp8:
fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
...@@ -774,6 +782,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -774,6 +782,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_obj_dgrad.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index]) ub_obj_dgrad.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index])
# Set UB algo and UB obj for fc1_dgrad bulk/pipelined overlap # Set UB algo and UB obj for fc1_dgrad bulk/pipelined overlap
rs_out = None
if ctx.ub_bulk_dgrad: if ctx.ub_bulk_dgrad:
ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG
ub_obj = ub_obj_lnout ub_obj = ub_obj_lnout
...@@ -923,6 +932,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -923,6 +932,7 @@ class _LayerNormMLP(torch.autograd.Function):
elif ctx.set_parallel_mode and ctx.tensor_parallel: elif ctx.set_parallel_mode and ctx.tensor_parallel:
fc1_dgrad, handle = allreduce(fc1_dgrad, ctx.tp_group, async_op=True) fc1_dgrad, handle = allreduce(fc1_dgrad, ctx.tp_group, async_op=True)
fc1_wgrad = None
if fc1_weight.requires_grad: if fc1_weight.requires_grad:
if ctx.fp8: if ctx.fp8:
# FC1 WGRAD # FC1 WGRAD
...@@ -1026,6 +1036,8 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1026,6 +1036,8 @@ class _LayerNormMLP(torch.autograd.Function):
if ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered: if ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered:
dgrad = dgrad + grad_outputs[1].view_as(dgrad) dgrad = dgrad + grad_outputs[1].view_as(dgrad)
dgamma = None
dbeta = None
if ctx.normalization == "LayerNorm": if ctx.normalization == "LayerNorm":
dgrad, dgamma, dbeta = tex.layernorm_bwd( dgrad, dgamma, dbeta = tex.layernorm_bwd(
dgrad, dgrad,
...@@ -1112,7 +1124,9 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1112,7 +1124,9 @@ class _LayerNormMLP(torch.autograd.Function):
dbeta, dbeta,
fc1_wgrad, fc1_wgrad,
None, # fc1_weight_fp8 None, # fc1_weight_fp8
fc1_bias_grad if ctx.use_fc1_bias else None, # Due to bias gelu nvfusion available in the bf16 case, fc1_bias_grad is calculated at
# different paths and this confused the linter.
fc1_bias_grad if ctx.use_fc1_bias else None, # pylint: disable=used-before-assignment
None, # use_fc1_bias None, # use_fc1_bias
fc2_wgrad, fc2_wgrad,
None, # fc2_weight_fp8 None, # fc2_weight_fp8
...@@ -1384,7 +1398,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1384,7 +1398,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
if with_fp8_params: if with_fp8_params:
self.init_fp8_metadata(num_gemms=2) self.init_fp8_metadata(num_gemms=2)
self.reset_parameters(defer_init=(device == "meta")) self.reset_parameters(defer_init=device == "meta")
# For RPL, bias has to be added after TP collectives # For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM # So it cannot be fused with the GEMM
......
...@@ -85,6 +85,7 @@ class _Linear(torch.autograd.Function): ...@@ -85,6 +85,7 @@ class _Linear(torch.autograd.Function):
fp8_output: bool, fp8_output: bool,
fsdp_group: Union[dist_group_type, None], fsdp_group: Union[dist_group_type, None],
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring
is_input_fp8 = isinstance(inp, Float8Tensor) is_input_fp8 = isinstance(inp, Float8Tensor)
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
...@@ -177,6 +178,8 @@ class _Linear(torch.autograd.Function): ...@@ -177,6 +178,8 @@ class _Linear(torch.autograd.Function):
activation_dtype, activation_dtype,
) )
ub_algo = None
rs_out = None
if ub_overlap_rs: if ub_overlap_rs:
ub_obj_projout = get_ub(ub_name + "_fprop") ub_obj_projout = get_ub(ub_name + "_fprop")
out = ub_obj_projout.get_ubuf_output(1) out = ub_obj_projout.get_ubuf_output(1)
...@@ -364,6 +367,7 @@ class _Linear(torch.autograd.Function): ...@@ -364,6 +367,7 @@ class _Linear(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
if isinstance(grad_output, Float8Tensor): if isinstance(grad_output, Float8Tensor):
ctx.fp8_meta["scaling_bwd"].scale_inv[ ctx.fp8_meta["scaling_bwd"].scale_inv[
tex.FP8BwdTensors.GRAD_OUTPUT1 tex.FP8BwdTensors.GRAD_OUTPUT1
...@@ -396,6 +400,7 @@ class _Linear(torch.autograd.Function): ...@@ -396,6 +400,7 @@ class _Linear(torch.autograd.Function):
tp_world_size = get_distributed_world_size(ctx.tp_group) tp_world_size = get_distributed_world_size(ctx.tp_group)
ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag
ub_algo = None
if ctx.ub_overlap_ag: if ctx.ub_overlap_ag:
dim_size = list(grad_output.size()) dim_size = list(grad_output.size())
dim_size[0] = dim_size[0] * tp_world_size dim_size[0] = dim_size[0] * tp_world_size
...@@ -507,6 +512,7 @@ class _Linear(torch.autograd.Function): ...@@ -507,6 +512,7 @@ class _Linear(torch.autograd.Function):
elif ctx.parallel_mode == "column" and ctx.tensor_parallel: elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True) dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)
wgrad = None
if weight.requires_grad: if weight.requires_grad:
if ctx.fp8: if ctx.fp8:
# WGRAD # WGRAD
...@@ -873,7 +879,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -873,7 +879,7 @@ class Linear(TransformerEngineBaseModule):
if with_fp8_params: if with_fp8_params:
self.init_fp8_metadata() self.init_fp8_metadata()
self.reset_parameters(defer_init=(device == "meta")) self.reset_parameters(defer_init=device == "meta")
# For RPL, bias has to be added after TP collectives # For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM # So it cannot be fused with the GEMM
......
...@@ -35,6 +35,7 @@ class _RMSNorm(torch.autograd.Function): ...@@ -35,6 +35,7 @@ class _RMSNorm(torch.autograd.Function):
is_grad_enabled: bool, is_grad_enabled: bool,
activation_dtype: torch.dtype, activation_dtype: torch.dtype,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = rmsnorm_weight.numel() in_features = rmsnorm_weight.numel()
assert inp.is_cuda, "TransformerEngine needs CUDA." assert inp.is_cuda, "TransformerEngine needs CUDA."
...@@ -61,6 +62,7 @@ class _RMSNorm(torch.autograd.Function): ...@@ -61,6 +62,7 @@ class _RMSNorm(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
inputmat, rmsnorm_weight, rsigma = ctx.saved_tensors inputmat, rmsnorm_weight, rsigma = ctx.saved_tensors
grad_output = grad_output.contiguous() grad_output = grad_output.contiguous()
d_rmsnorm_out = grad_output.view(inputmat.shape) d_rmsnorm_out = grad_output.view(inputmat.shape)
...@@ -147,7 +149,7 @@ class RMSNorm(torch.nn.Module): ...@@ -147,7 +149,7 @@ class RMSNorm(torch.nn.Module):
self.sequence_parallel = sequence_parallel self.sequence_parallel = sequence_parallel
self.activation_dtype: Optional[torch.dtype] = None self.activation_dtype: Optional[torch.dtype] = None
self.reset_parameters(defer_init=(device == "meta")) self.reset_parameters(defer_init=device == "meta")
# These many SMs are subtracted from the total SM count when calling forward # These many SMs are subtracted from the total SM count when calling forward
# and backward RMSNorm C APIs. These envvars can be used to prevent the LN # and backward RMSNorm C APIs. These envvars can be used to prevent the LN
...@@ -182,7 +184,7 @@ class RMSNorm(torch.nn.Module): ...@@ -182,7 +184,7 @@ class RMSNorm(torch.nn.Module):
@no_torch_dynamo() @no_torch_dynamo()
def forward(self, inp: torch.Tensor) -> torch.Tensor: def forward(self, inp: torch.Tensor) -> torch.Tensor:
"""RMSNorm FWD""" # pylint: disable=missing-function-docstring
# Set the activation type for AMP. # Set the activation type for AMP.
# Note: This will soon be deprecated with # Note: This will soon be deprecated with
......
...@@ -562,12 +562,12 @@ class BasicLinear(BasicOperation): ...@@ -562,12 +562,12 @@ class BasicLinear(BasicOperation):
_wait_async(x_async) _wait_async(x_async)
x_async = None x_async = None
if with_fp8_compute: if with_fp8_compute:
kwargs = dict( kwargs = {
accumulate=accumulate_into_out, "accumulate": accumulate_into_out,
out=y, "out": y,
bias=b, "bias": b,
use_bias=(b is not None), "use_bias": (b is not None),
) }
if with_fp8_output: if with_fp8_output:
if y._fp8_meta is None: if y._fp8_meta is None:
# Hackily create FP8TensorMeta if needed # Hackily create FP8TensorMeta if needed
...@@ -584,12 +584,12 @@ class BasicLinear(BasicOperation): ...@@ -584,12 +584,12 @@ class BasicLinear(BasicOperation):
fp8_meta = y._fp8_meta[fp8_meta_key] fp8_meta = y._fp8_meta[fp8_meta_key]
fp8_meta_index = y._fp8_meta_index fp8_meta_index = y._fp8_meta_index
kwargs.update( kwargs.update(
dict( {
out=y._data, "out": y._data,
out_index=fp8_meta_index, "out_index": fp8_meta_index,
fp8_meta_tensor=fp8_meta, "fp8_meta_tensor": fp8_meta,
D_dtype=y._fp8_dtype, "D_dtype": y._fp8_dtype,
) }
) )
fp8_gemm( fp8_gemm(
w._data, w._data,
...@@ -936,10 +936,7 @@ class BasicLinear(BasicOperation): ...@@ -936,10 +936,7 @@ class BasicLinear(BasicOperation):
_wait_async(dy_async) _wait_async(dy_async)
dy_async = None dy_async = None
if with_fp8_compute: if with_fp8_compute:
kwargs = dict( kwargs = {"accumulate": accumulate_into_grad_input, "out": dx}
accumulate=accumulate_into_grad_input,
out=dx,
)
if with_fp8_grad_input: if with_fp8_grad_input:
if dx._fp8_meta is None: if dx._fp8_meta is None:
# Hackily create FP8TensorMeta if needed # Hackily create FP8TensorMeta if needed
...@@ -958,12 +955,12 @@ class BasicLinear(BasicOperation): ...@@ -958,12 +955,12 @@ class BasicLinear(BasicOperation):
fp8_meta = dx._fp8_meta[fp8_meta_key] fp8_meta = dx._fp8_meta[fp8_meta_key]
fp8_meta_index = dx._fp8_meta_index fp8_meta_index = dx._fp8_meta_index
kwargs.update( kwargs.update(
dict( {
out=dx._data, "out": dx._data,
out_index=fp8_meta_index, "out_index": fp8_meta_index,
fp8_meta_tensor=fp8_meta, "fp8_meta_tensor": fp8_meta,
D_dtype=dx._fp8_dtype, "D_dtype": dx._fp8_dtype,
) }
) )
fp8_gemm( fp8_gemm(
w.transpose_2d(), w.transpose_2d(),
......
...@@ -38,11 +38,7 @@ class ForwardLinearBiasActivation(FusedOperation): ...@@ -38,11 +38,7 @@ class ForwardLinearBiasActivation(FusedOperation):
) -> None: ) -> None:
# Basic operations that comprise this fused operation # Basic operations that comprise this fused operation
op_idxs = dict( op_idxs = {"linear": 0, "bias": None, "activation": None}
linear=0,
bias=None,
activation=None,
)
ops = [linear] ops = [linear]
if bias is not None: if bias is not None:
op_idxs["bias"] = len(ops) op_idxs["bias"] = len(ops)
......
...@@ -37,11 +37,7 @@ class ForwardLinearBiasAdd(FusedOperation): ...@@ -37,11 +37,7 @@ class ForwardLinearBiasAdd(FusedOperation):
) -> None: ) -> None:
# Basic operations that comprise this fused operation # Basic operations that comprise this fused operation
op_idxs = dict( op_idxs = {"linear": 0, "bias": None, "add": None}
linear=0,
bias=None,
add=None,
)
ops = [linear] ops = [linear]
if bias is not None: if bias is not None:
op_idxs["bias"] = len(ops) op_idxs["bias"] = len(ops)
......
...@@ -91,24 +91,24 @@ class Linear(FusedOperation): ...@@ -91,24 +91,24 @@ class Linear(FusedOperation):
# Construct basic ops # Construct basic ops
ops = [] ops = []
linear_kwargs = dict( linear_kwargs = {
in_features=in_features, "in_features": in_features,
out_features=out_features, "out_features": out_features,
device=device, "device": device,
dtype=dtype, "dtype": dtype,
tensor_parallel_mode=tensor_parallel_mode, "tensor_parallel_mode": tensor_parallel_mode,
tensor_parallel_group=tensor_parallel_group, "tensor_parallel_group": tensor_parallel_group,
sequence_parallel=sequence_parallel, "sequence_parallel": sequence_parallel,
rng_state_tracker_function=rng_state_tracker_function, "rng_state_tracker_function": rng_state_tracker_function,
accumulate_into_main_grad=accumulate_into_main_grad, "accumulate_into_main_grad": accumulate_into_main_grad,
) }
bias_kwargs = dict( bias_kwargs = {
size=out_features, "size": out_features,
device=device, "device": device,
dtype=dtype, "dtype": dtype,
tensor_parallel=(tensor_parallel_mode is not None), "tensor_parallel": (tensor_parallel_mode is not None),
tensor_parallel_group=tensor_parallel_group, "tensor_parallel_group": tensor_parallel_group,
) }
if tensor_parallel_mode == "row": if tensor_parallel_mode == "row":
# Row TP: GEMM + bias + reduction # Row TP: GEMM + bias + reduction
linear_kwargs["in_features"] = local_in_features linear_kwargs["in_features"] = local_in_features
......
...@@ -179,7 +179,6 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -179,7 +179,6 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
def is_fused_op(self) -> bool: def is_fused_op(self) -> bool:
return False return False
# pylint: disable=no-self-use
def num_fp8_scales( def num_fp8_scales(
self, self,
mode: str, # pylint: disable=unused-argument mode: str, # pylint: disable=unused-argument
...@@ -225,11 +224,11 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -225,11 +224,11 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
} }
# Construct FP8 metadata for all tensor types # Construct FP8 metadata for all tensor types
return dict( return {
input=_make_meta(self.num_fp8_scales("input"), True), "input": _make_meta(self.num_fp8_scales("input"), True),
param=_make_meta(self.num_fp8_scales("param"), True), "param": _make_meta(self.num_fp8_scales("param"), True),
grad_output=_make_meta(self.num_fp8_scales("grad_output"), False), "grad_output": _make_meta(self.num_fp8_scales("grad_output"), False),
) }
@classmethod @classmethod
def _maybe_update_fp8_meta(cls, fp8_meta: Optional[dict[str, Any]]) -> None: def _maybe_update_fp8_meta(cls, fp8_meta: Optional[dict[str, Any]]) -> None:
......
...@@ -46,6 +46,7 @@ class Sequential(torch.nn.Module): ...@@ -46,6 +46,7 @@ class Sequential(torch.nn.Module):
self.append(module) self.append(module)
def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None: def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None:
# pylint: disable=missing-function-docstring
self._module_groups = None self._module_groups = None
super().add_module(name, module) super().add_module(name, module)
......
...@@ -100,13 +100,13 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -100,13 +100,13 @@ class FusedAdam(torch.optim.Optimizer):
# If the optimizer is capturable then LR should be a tensor (on GPU) # If the optimizer is capturable then LR should be a tensor (on GPU)
lr = torch.tensor(lr, dtype=torch.float32) if capturable else lr lr = torch.tensor(lr, dtype=torch.float32) if capturable else lr
defaults = dict( defaults = {
lr=lr, "lr": lr,
bias_correction=bias_correction, "bias_correction": bias_correction,
betas=betas, "betas": betas,
eps=eps, "eps": eps,
weight_decay=weight_decay, "weight_decay": weight_decay,
) }
super().__init__(params, defaults) super().__init__(params, defaults)
self.adam_w_mode = 1 if adam_w_mode else 0 self.adam_w_mode = 1 if adam_w_mode else 0
self.set_grad_none = set_grad_none self.set_grad_none = set_grad_none
...@@ -135,6 +135,7 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -135,6 +135,7 @@ class FusedAdam(torch.optim.Optimizer):
self.multi_tensor_adam_capturable_master = tex.multi_tensor_adam_capturable_master self.multi_tensor_adam_capturable_master = tex.multi_tensor_adam_capturable_master
def zero_grad(self): def zero_grad(self):
# pylint: disable=missing-function-docstring
if self.set_grad_none: if self.set_grad_none:
for group in self.param_groups: for group in self.param_groups:
for p in group["params"]: for p in group["params"]:
......
...@@ -91,13 +91,13 @@ class FusedSGD(Optimizer): ...@@ -91,13 +91,13 @@ class FusedSGD(Optimizer):
if weight_decay < 0.0: if weight_decay < 0.0:
raise ValueError(f"Invalid weight_decay value: {weight_decay}") raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict( defaults = {
lr=lr, "lr": lr,
momentum=momentum, "momentum": momentum,
dampening=dampening, "dampening": dampening,
weight_decay=weight_decay, "weight_decay": weight_decay,
nesterov=nesterov, "nesterov": nesterov,
) }
if nesterov and (momentum <= 0 or dampening != 0): if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening") raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super().__init__(params, defaults) super().__init__(params, defaults)
...@@ -120,6 +120,7 @@ class FusedSGD(Optimizer): ...@@ -120,6 +120,7 @@ class FusedSGD(Optimizer):
group.setdefault("nesterov", False) group.setdefault("nesterov", False)
def zero_grad(self): def zero_grad(self):
# pylint: disable=missing-function-docstring
if self.set_grad_none: if self.set_grad_none:
for group in self.param_groups: for group in self.param_groups:
for p in group["params"]: for p in group["params"]:
......
...@@ -32,6 +32,7 @@ class _moe_permute(torch.autograd.Function): ...@@ -32,6 +32,7 @@ class _moe_permute(torch.autograd.Function):
num_out_tokens: int, num_out_tokens: int,
max_token_num: int, max_token_num: int,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# pylint: disable=missing-function-docstring
# Empty input check # Empty input check
if not inp.numel(): if not inp.numel():
return inp, torch.tensor([], device=inp.device) return inp, torch.tensor([], device=inp.device)
...@@ -90,6 +91,7 @@ class _moe_permute(torch.autograd.Function): ...@@ -90,6 +91,7 @@ class _moe_permute(torch.autograd.Function):
permuted_act_grad: torch.Tensor, permuted_act_grad: torch.Tensor,
_, _,
) -> Tuple[torch.Tensor, ...]: ) -> Tuple[torch.Tensor, ...]:
# pylint: disable=missing-function-docstring
# Empty input check # Empty input check
if not permuted_act_grad.numel(): if not permuted_act_grad.numel():
return permuted_act_grad, None, None, None return permuted_act_grad, None, None, None
...@@ -130,6 +132,7 @@ class _moe_unpermute(torch.autograd.Function): ...@@ -130,6 +132,7 @@ class _moe_unpermute(torch.autograd.Function):
row_id_map: torch.Tensor, row_id_map: torch.Tensor,
probs: torch.Tensor, probs: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring
# Empty input check # Empty input check
if not inp.numel(): if not inp.numel():
ctx.probs = probs ctx.probs = probs
...@@ -188,6 +191,7 @@ class _moe_unpermute(torch.autograd.Function): ...@@ -188,6 +191,7 @@ class _moe_unpermute(torch.autograd.Function):
ctx, ctx,
unpermuted_act_grad: torch.Tensor, unpermuted_act_grad: torch.Tensor,
) -> Tuple[torch.Tensor, None, torch.Tensor]: ) -> Tuple[torch.Tensor, None, torch.Tensor]:
# pylint: disable=missing-function-docstring
# Empty input check # Empty input check
if not unpermuted_act_grad.numel(): if not unpermuted_act_grad.numel():
return unpermuted_act_grad, None, ctx.probs return unpermuted_act_grad, None, ctx.probs
...@@ -208,6 +212,7 @@ class _moe_unpermute(torch.autograd.Function): ...@@ -208,6 +212,7 @@ class _moe_unpermute(torch.autograd.Function):
inp, row_id_map, probs = ctx.saved_tensors inp, row_id_map, probs = ctx.saved_tensors
act_grad = None act_grad = None
prob_grad = None
if ctx.needs_input_grad[0]: if ctx.needs_input_grad[0]:
act_grad, prob_grad = tex.moe_unpermute_bwd( act_grad, prob_grad = tex.moe_unpermute_bwd(
unpermuted_act_grad, inp, dtype, row_id_map, probs unpermuted_act_grad, inp, dtype, row_id_map, probs
......
...@@ -49,7 +49,7 @@ def _make_fp8_attr_property_funcs(name: str) -> Any: ...@@ -49,7 +49,7 @@ def _make_fp8_attr_property_funcs(name: str) -> Any:
def del_func(self) -> None: def del_func(self) -> None:
del self._fp8_attrs[name] del self._fp8_attrs[name]
return dict(fget=get_func, fset=set_func, fdel=del_func) return {"fget": get_func, "fset": set_func, "fdel": del_func}
class _FromFloat8Func(torch.autograd.Function): class _FromFloat8Func(torch.autograd.Function):
...@@ -61,6 +61,7 @@ class _FromFloat8Func(torch.autograd.Function): ...@@ -61,6 +61,7 @@ class _FromFloat8Func(torch.autograd.Function):
tensor: Float8Tensor, tensor: Float8Tensor,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring
return tensor.dequantize(dtype=dtype) return tensor.dequantize(dtype=dtype)
@staticmethod @staticmethod
...@@ -68,6 +69,7 @@ class _FromFloat8Func(torch.autograd.Function): ...@@ -68,6 +69,7 @@ class _FromFloat8Func(torch.autograd.Function):
_ctx: torch.autograd.function.FunctionCtx, # unused _ctx: torch.autograd.function.FunctionCtx, # unused
grad: torch.Tensor, grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]: ) -> Tuple[Optional[torch.Tensor], ...]:
# pylint: disable=missing-function-docstring
# Assume that we want gradients in full precision # Assume that we want gradients in full precision
return grad, None return grad, None
...@@ -112,6 +114,7 @@ class _ToFloat8Func(torch.autograd.Function): ...@@ -112,6 +114,7 @@ class _ToFloat8Func(torch.autograd.Function):
scale_inv: Optional[torch.Tensor] = None, scale_inv: Optional[torch.Tensor] = None,
with_transpose_cache: bool = False, with_transpose_cache: bool = False,
) -> Float8Tensor: ) -> Float8Tensor:
# pylint: disable=missing-function-docstring
# Tensor attributes # Tensor attributes
dtype = tensor.dtype dtype = tensor.dtype
...@@ -167,6 +170,7 @@ class _ToFloat8Func(torch.autograd.Function): ...@@ -167,6 +170,7 @@ class _ToFloat8Func(torch.autograd.Function):
_ctx: torch.autograd.function.FunctionCtx, # unused _ctx: torch.autograd.function.FunctionCtx, # unused
grad: torch.Tensor, grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]: ) -> Tuple[Optional[torch.Tensor], ...]:
# pylint: disable=missing-function-docstring
# Assume that we want gradients in full precision # Assume that we want gradients in full precision
return grad, None, None, None, None, None, None, None return grad, None, None, None, None, None, None, None
...@@ -185,6 +189,7 @@ class _IdentityFunc(torch.autograd.Function): ...@@ -185,6 +189,7 @@ class _IdentityFunc(torch.autograd.Function):
tensor: Float8Tensor, tensor: Float8Tensor,
init_kwargs: Optional[Dict[str, Any]] = None, init_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring
# Return input tensor if constructor kwargs are not provided # Return input tensor if constructor kwargs are not provided
ctx.input_dtype = tensor.dtype ctx.input_dtype = tensor.dtype
...@@ -192,15 +197,15 @@ class _IdentityFunc(torch.autograd.Function): ...@@ -192,15 +197,15 @@ class _IdentityFunc(torch.autograd.Function):
return tensor return tensor
# Construct new tensor if constructor kwargs are provided # Construct new tensor if constructor kwargs are provided
default_kwargs = dict( default_kwargs = {
data=tensor._data, "data": tensor._data,
fp8_meta=tensor._fp8_meta, "fp8_meta": tensor._fp8_meta,
fp8_meta_forward=tensor._fp8_meta_forward, "fp8_meta_forward": tensor._fp8_meta_forward,
fp8_meta_index=tensor._fp8_meta_index, "fp8_meta_index": tensor._fp8_meta_index,
fp8_dtype=tensor._fp8_dtype, "fp8_dtype": tensor._fp8_dtype,
fp8_scale_inv=tensor._scale_inv, "fp8_scale_inv": tensor._scale_inv,
dtype=tensor.dtype, "dtype": tensor.dtype,
) }
for key, val in default_kwargs.items(): for key, val in default_kwargs.items():
if key not in init_kwargs: if key not in init_kwargs:
init_kwargs[key] = val init_kwargs[key] = val
...@@ -208,6 +213,7 @@ class _IdentityFunc(torch.autograd.Function): ...@@ -208,6 +213,7 @@ class _IdentityFunc(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, grad): def backward(ctx, grad):
# pylint: disable=missing-function-docstring
return grad.to(ctx.input_dtype), None return grad.to(ctx.input_dtype), None
...@@ -224,6 +230,7 @@ class _ViewFunc(torch.autograd.Function): ...@@ -224,6 +230,7 @@ class _ViewFunc(torch.autograd.Function):
tensor: torch.Tensor, tensor: torch.Tensor,
shape: Tuple[int] = None, shape: Tuple[int] = None,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring
# Return input tensor if shape is not provided # Return input tensor if shape is not provided
ctx.shape = tensor.shape ctx.shape = tensor.shape
...@@ -243,6 +250,7 @@ class _ViewFunc(torch.autograd.Function): ...@@ -243,6 +250,7 @@ class _ViewFunc(torch.autograd.Function):
ctx, ctx,
grad: torch.Tensor, grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]: ) -> Tuple[Optional[torch.Tensor], ...]:
# pylint: disable=missing-function-docstring
if isinstance(grad, Float8Tensor): if isinstance(grad, Float8Tensor):
dgrad = Float8Tensor.make_like( dgrad = Float8Tensor.make_like(
...@@ -266,6 +274,7 @@ class _ReshapeFunc(torch.autograd.Function): ...@@ -266,6 +274,7 @@ class _ReshapeFunc(torch.autograd.Function):
tensor: torch.Tensor, tensor: torch.Tensor,
shape: Tuple[int] = None, shape: Tuple[int] = None,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring
# Return input tensor if shape is not provided # Return input tensor if shape is not provided
ctx.shape = tensor.shape ctx.shape = tensor.shape
...@@ -285,6 +294,7 @@ class _ReshapeFunc(torch.autograd.Function): ...@@ -285,6 +294,7 @@ class _ReshapeFunc(torch.autograd.Function):
ctx, ctx,
grad: torch.Tensor, grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]: ) -> Tuple[Optional[torch.Tensor], ...]:
# pylint: disable=missing-function-docstring
if isinstance(grad, Float8Tensor): if isinstance(grad, Float8Tensor):
dgrad = Float8Tensor.make_like( dgrad = Float8Tensor.make_like(
...@@ -456,14 +466,14 @@ class Float8Tensor(QuantizedTensor): ...@@ -456,14 +466,14 @@ class Float8Tensor(QuantizedTensor):
See constructor for list of keyword arguments. See constructor for list of keyword arguments.
""" """
default_kwargs = dict( default_kwargs = {
fp8_meta=tensor._fp8_meta, "fp8_meta": tensor._fp8_meta,
fp8_meta_forward=tensor._fp8_meta_forward, "fp8_meta_forward": tensor._fp8_meta_forward,
fp8_meta_index=tensor._fp8_meta_index, "fp8_meta_index": tensor._fp8_meta_index,
fp8_dtype=tensor._fp8_dtype, "fp8_dtype": tensor._fp8_dtype,
fp8_scale_inv=tensor._scale_inv, "fp8_scale_inv": tensor._scale_inv,
dtype=tensor.dtype, "dtype": tensor.dtype,
) }
for key, val in default_kwargs.items(): for key, val in default_kwargs.items():
if key not in kwargs: if key not in kwargs:
kwargs[key] = val kwargs[key] = val
...@@ -697,6 +707,7 @@ class Float8Tensor(QuantizedTensor): ...@@ -697,6 +707,7 @@ class Float8Tensor(QuantizedTensor):
) )
def detach(self) -> Float8Tensor: def detach(self) -> Float8Tensor:
# pylint: disable=missing-function-docstring
return Float8Tensor.make_like( return Float8Tensor.make_like(
self, self,
data=self._data, data=self._data,
...@@ -704,22 +715,25 @@ class Float8Tensor(QuantizedTensor): ...@@ -704,22 +715,25 @@ class Float8Tensor(QuantizedTensor):
) )
def clone(self) -> Float8Tensor: def clone(self) -> Float8Tensor:
# pylint: disable=missing-function-docstring
data = self._data.detach().clone() data = self._data.detach().clone()
data_transpose = None data_transpose = None
if self._transpose is not None: if self._transpose is not None:
data_transpose = self._transpose.detach().clone() data_transpose = self._transpose.detach().clone()
return _IdentityFunc.apply( return _IdentityFunc.apply(
self, self,
dict( {
data=data, "data": data,
data_transpose=data_transpose, "data_transpose": data_transpose,
), },
) )
def view(self, *shape: Tuple[int]) -> Float8Tensor: def view(self, *shape: Tuple[int]) -> Float8Tensor:
# pylint: disable=missing-function-docstring
return _ViewFunc.apply(self, shape) return _ViewFunc.apply(self, shape)
def reshape(self, *shape: Tuple[int]) -> Float8Tensor: def reshape(self, *shape: Tuple[int]) -> Float8Tensor:
# pylint: disable=missing-function-docstring
return _ReshapeFunc.apply(self, shape) return _ReshapeFunc.apply(self, shape)
def contiguous( def contiguous(
...@@ -980,6 +994,7 @@ class Float8Tensor(QuantizedTensor): ...@@ -980,6 +994,7 @@ class Float8Tensor(QuantizedTensor):
requires_grad=tensor.requires_grad, requires_grad=tensor.requires_grad,
device=new_device, device=new_device,
) )
# pylint: disable=unnecessary-dunder-call
super(Float8Tensor, type(self)).data.__set__(self, dummy_tensor) super(Float8Tensor, type(self)).data.__set__(self, dummy_tensor)
self._data = tensor._data self._data = tensor._data
self._fp8_attrs = tensor._fp8_attrs self._fp8_attrs = tensor._fp8_attrs
...@@ -1008,6 +1023,7 @@ class Float8Tensor(QuantizedTensor): ...@@ -1008,6 +1023,7 @@ class Float8Tensor(QuantizedTensor):
requires_grad=tensor.requires_grad, requires_grad=tensor.requires_grad,
device=self._data.device, device=self._data.device,
) )
# pylint: disable=unnecessary-dunder-call
super(Float8Tensor, type(self)).data.__set__(self, dummy_tensor) super(Float8Tensor, type(self)).data.__set__(self, dummy_tensor)
if self._transpose is not None: if self._transpose is not None:
self._transpose = torch.empty( self._transpose = torch.empty(
......
...@@ -20,6 +20,7 @@ class _DequantizeFunc(torch.autograd.Function): ...@@ -20,6 +20,7 @@ class _DequantizeFunc(torch.autograd.Function):
tensor: QuantizedTensor, tensor: QuantizedTensor,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring
return tensor.dequantize(dtype=dtype) return tensor.dequantize(dtype=dtype)
@staticmethod @staticmethod
...@@ -27,6 +28,7 @@ class _DequantizeFunc(torch.autograd.Function): ...@@ -27,6 +28,7 @@ class _DequantizeFunc(torch.autograd.Function):
_ctx: torch.autograd.function.FunctionCtx, # unused _ctx: torch.autograd.function.FunctionCtx, # unused
grad: torch.Tensor, grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]: ) -> Tuple[Optional[torch.Tensor], ...]:
# pylint: disable=missing-function-docstring
return grad, None return grad, None
...@@ -38,6 +40,7 @@ class _IdentityFunc(torch.autograd.Function): ...@@ -38,6 +40,7 @@ class _IdentityFunc(torch.autograd.Function):
_ctx: torch.autograd.function.FunctionCtx, # unused _ctx: torch.autograd.function.FunctionCtx, # unused
tensor: QuantizedTensor, tensor: QuantizedTensor,
) -> QuantizedTensor: ) -> QuantizedTensor:
# pylint: disable=missing-function-docstring
return tensor.detach() return tensor.detach()
@staticmethod @staticmethod
...@@ -45,6 +48,7 @@ class _IdentityFunc(torch.autograd.Function): ...@@ -45,6 +48,7 @@ class _IdentityFunc(torch.autograd.Function):
_ctx: torch.autograd.function.FunctionCtx, # unused _ctx: torch.autograd.function.FunctionCtx, # unused
grad: torch.Tensor, grad: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring
return grad return grad
...@@ -85,18 +89,23 @@ class QuantizedTensor(torch.Tensor): ...@@ -85,18 +89,23 @@ class QuantizedTensor(torch.Tensor):
return f"{self.__class__.__name__}(data={self.dequantize(dtype=self.dtype)})" return f"{self.__class__.__name__}(data={self.dequantize(dtype=self.dtype)})"
def float(self) -> torch.Tensor: def float(self) -> torch.Tensor:
# pylint: disable=missing-function-docstring
return _DequantizeFunc.apply(self, torch.float32) return _DequantizeFunc.apply(self, torch.float32)
def bfloat16(self) -> torch.Tensor: def bfloat16(self) -> torch.Tensor:
# pylint: disable=missing-function-docstring
return _DequantizeFunc.apply(self, torch.bfloat16) return _DequantizeFunc.apply(self, torch.bfloat16)
def half(self) -> torch.Tensor: def half(self) -> torch.Tensor:
# pylint: disable=missing-function-docstring
return _DequantizeFunc.apply(self, torch.float16) return _DequantizeFunc.apply(self, torch.float16)
def cpu(self) -> torch.Tensor: def cpu(self) -> torch.Tensor:
# pylint: disable=missing-function-docstring
return _DequantizeFunc.apply(self).cpu() return _DequantizeFunc.apply(self).cpu()
def expand_as(self, other: torch.Tensor) -> torch.Tensor: def expand_as(self, other: torch.Tensor) -> torch.Tensor:
# pylint: disable=missing-function-docstring
if other is self: if other is self:
# Note: expand_as is hackily used to create dummy autograd nodes # Note: expand_as is hackily used to create dummy autograd nodes
# and access the backward graph (see # and access the backward graph (see
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment