Unverified commit 6e90fcb7, authored by Kirthi Shankar Sivamani, committed by GitHub

Upgrade pylint to 3.3.1 (#1257)



* Upgrade pylint and first round formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* round 2
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* round 3
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Format and fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Paddle lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Reviews
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* More linting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Run formatter
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Paddle lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 161b1d98
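
Note on the diff below: this is mechanical lint cleanup for pylint 3.3.1 rather than a functional change. Three patterns recur: `dict(...)` constructor calls rewritten as dict literals (pylint's `use-dict-literal`), variables bound to `None` before conditional assignment (satisfying the `possibly-used-before-assignment` check introduced in pylint 3.2), and targeted `# pylint: disable=missing-function-docstring` pragmas. A minimal sketch of the patterns, using hypothetical names rather than code from this diff:

```python
# Hedged illustration of the recurring lint fixes; all names are hypothetical.

def configure(device: str = "cuda"):
    # use-dict-literal: prefer a literal over a dict() constructor call.
    kwargs = {"bias": None, "use_bias": False}  # was: dict(bias=None, use_bias=False)

    # possibly-used-before-assignment: bind a default before the branch so the
    # name is defined on every code path.
    ub_algo = None
    if device == "cuda":
        ub_algo = "BULK_OVERLAP_AG"

    # Redundant parentheses around keyword-argument expressions are also
    # dropped throughout, e.g. reset_parameters(defer_init=device == "meta").
    defer_init = device == "meta"  # was: defer_init = (device == "meta")
    return kwargs, ub_algo, defer_init
```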
@@ -38,6 +38,7 @@ class _LayerNorm(torch.autograd.Function):
is_grad_enabled: bool,
activation_dtype: torch.dtype,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
assert inp.is_cuda, "TransformerEngine needs CUDA."
@@ -69,6 +70,7 @@ class _LayerNorm(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
inputmat, ln_weight, mu, rsigma = ctx.saved_tensors
grad_output = grad_output.contiguous()
d_ln_out = grad_output.view(inputmat.shape)
@@ -144,7 +146,7 @@ class LayerNorm(torch.nn.Module):
self.sequence_parallel = sequence_parallel
self.activation_dtype: Optional[torch.dtype] = None
self.reset_parameters(defer_init=(device == "meta"))
self.reset_parameters(defer_init=device == "meta")
# This many SMs are subtracted from the total SM count when calling forward
# and backward LayerNorm C APIs. These envvars can be used to prevent the LN
@@ -185,7 +187,7 @@ class LayerNorm(torch.nn.Module):
@no_torch_dynamo()
def forward(self, inp: torch.Tensor) -> torch.Tensor:
"""LayerNorm FWD"""
# pylint: disable=missing-function-docstring
# Set the activation type for AMP.
# Note: This will soon be deprecated with
......
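
Note: most `# pylint: disable=missing-function-docstring` additions in this diff land on `torch.autograd.Function.forward`/`backward` overrides, whose signatures are fixed by the torch API; the class docstring documents the behavior instead. Placed as the first statement, the pragma is scoped to that single function. A minimal hedged sketch (the `_Identity` op is hypothetical):

```python
import torch

class _Identity(torch.autograd.Function):
    """Toy identity op; the class docstring carries the documentation."""

    @staticmethod
    def forward(ctx, inp: torch.Tensor) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        return inp

    @staticmethod
    def backward(ctx, grad: torch.Tensor) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        return grad

out = _Identity.apply(torch.ones(2, requires_grad=True))  # behaves like a no-op
```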
@@ -94,6 +94,7 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_output: bool,
fsdp_group: Union[dist_group_type, None],
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# pylint: disable=missing-function-docstring
# Make sure input dimensions are compatible
out_features, in_features = weight.shape
inp_shape = inp.shape
@@ -153,6 +154,7 @@ class _LayerNormLinear(torch.autograd.Function):
# Column Parallel Linear
ln_out_gathered = False
ub_algo = None
if ub_overlap_ag:
ln_out_total = ub_obj_lnout.get_ubuf_output(1)
if not return_layernorm_output:
@@ -385,6 +387,7 @@ class _LayerNormLinear(torch.autograd.Function):
def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
if isinstance(grad_outputs[0], Float8Tensor):
ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1] = grad_outputs[
0
@@ -479,6 +482,7 @@ class _LayerNormLinear(torch.autograd.Function):
else:
dgrad = torch.empty(dgrad_size, dtype=ctx.activation_dtype, device=weight.device)
rs_out = None
if ctx.ub_bulk_dgrad:
ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG
ub_obj = ub_obj_lnout
@@ -576,6 +580,7 @@ class _LayerNormLinear(torch.autograd.Function):
elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)
wgrad = None
if weight.requires_grad:
if ctx.fp8:
# WGRAD
@@ -678,6 +683,8 @@ class _LayerNormLinear(torch.autograd.Function):
if ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered:
dgrad = dgrad + grad_outputs[1].view_as(dgrad)
dgamma = None
dbeta = None
if ctx.normalization == "LayerNorm":
dgrad, dgamma, dbeta = tex.layernorm_bwd(
dgrad,
@@ -1057,7 +1064,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
if with_fp8_params:
self.init_fp8_metadata()
self.reset_parameters(defer_init=(device == "meta"))
self.reset_parameters(defer_init=device == "meta")
# For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM
......
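
Note: the backward hunks above now bind `wgrad`, `dgamma`, and `dbeta` to `None` before the branches that may compute them. Besides satisfying `possibly-used-before-assignment`, this fits the `torch.autograd.Function` contract: `backward` returns one entry per `forward` input, with `None` for inputs that receive no gradient. A hedged toy version of that shape (not the real kernels):

```python
import torch

class _Scale(torch.autograd.Function):
    """Toy op y = x * s used to illustrate None-default gradient returns."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        ctx.save_for_backward(x, s)
        return x * s

    @staticmethod
    def backward(ctx, grad: torch.Tensor):
        # pylint: disable=missing-function-docstring
        x, s = ctx.saved_tensors
        # Bind both gradients up front so every return path is well-defined.
        dx = None
        ds = None
        if ctx.needs_input_grad[0]:
            dx = grad * s
        if ctx.needs_input_grad[1]:
            ds = (grad * x).sum()
        return dx, ds  # one entry per forward input
```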
@@ -123,6 +123,7 @@ class _LayerNormMLP(torch.autograd.Function):
gemm_gelu_fusion: bool,
fsdp_group: Union[dist_group_type, None],
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# pylint: disable=missing-function-docstring
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
inp_shape = inp.shape
@@ -173,6 +174,7 @@ class _LayerNormMLP(torch.autograd.Function):
# Column Parallel Linear
ln_out_gathered = False
ub_algo_ag = None
if ub_overlap_ag:
ln_out_total = ub_obj_lnout.get_ubuf_output(1)
ln_out = torch.empty_like(ln_out)
@@ -241,23 +243,23 @@ class _LayerNormMLP(torch.autograd.Function):
activation_dtype,
get_workspace(),
]
fp8_gemm_kwargs = dict(
bias=fc1_bias,
use_bias=use_fc1_bias,
use_split_accumulator=_2X_ACC_FPROP,
ub_algo=ub_algo_ag if ub_overlap_ag else None,
ub=ub_obj_lnout if ub_overlap_ag else None,
extra_output_tensor=ln_out if ub_overlap_ag else None,
)
fp8_gemm_kwargs = {
"bias": fc1_bias,
"use_bias": use_fc1_bias,
"use_split_accumulator": _2X_ACC_FPROP,
"ub_algo": ub_algo_ag if ub_overlap_ag else None,
"ub": ub_obj_lnout if ub_overlap_ag else None,
"extra_output_tensor": ln_out if ub_overlap_ag else None,
}
if gemm_gelu_fusion:
fp8_gemm_args[8] = torch.uint8 # out_dtype
fp8_gemm_kwargs.update(
dict(
gelu=True,
out_index=tex.FP8FwdTensors.GEMM2_INPUT,
fp8_meta_tensor=fp8_meta["scaling_fwd"],
D_dtype=fp8_dtype_forward,
)
{
"gelu": True,
"out_index": tex.FP8FwdTensors.GEMM2_INPUT,
"fp8_meta_tensor": fp8_meta["scaling_fwd"],
"D_dtype": fp8_dtype_forward,
}
)
fp8_gemm_out = tex.fp8_gemm(*fp8_gemm_args, **fp8_gemm_kwargs)
if not is_grad_enabled:
@@ -283,6 +285,9 @@ class _LayerNormMLP(torch.autograd.Function):
None,
activation_dtype,
)
rs_out = None
ub_algo_rs = None
if ub_overlap_rs:
ub_obj_fc2out = get_ub("fc2_fprop")
fc2_out = ub_obj_fc2out.get_ubuf_output(1)
@@ -536,6 +541,7 @@ class _LayerNormMLP(torch.autograd.Function):
def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
with torch.cuda.nvtx.range("_LayerNormMLP_backward"):
(
inputmat,
@@ -599,6 +605,7 @@ class _LayerNormMLP(torch.autograd.Function):
if tp_world_size == 1:
ctx.ub_overlap_ag = False
ub_algo = None
if ctx.ub_overlap_ag:
dim_size = list(grad_outputs[0].size())
dim_size[0] = dim_size[0] * tp_world_size
@@ -640,6 +647,7 @@ class _LayerNormMLP(torch.autograd.Function):
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
fc2_wgrad = None
if ctx.fp8:
fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
@@ -774,6 +782,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_obj_dgrad.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index])
# Set UB algo and UB obj for fc1_dgrad bulk/pipelined overlap
rs_out = None
if ctx.ub_bulk_dgrad:
ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG
ub_obj = ub_obj_lnout
@@ -923,6 +932,7 @@ class _LayerNormMLP(torch.autograd.Function):
elif ctx.set_parallel_mode and ctx.tensor_parallel:
fc1_dgrad, handle = allreduce(fc1_dgrad, ctx.tp_group, async_op=True)
fc1_wgrad = None
if fc1_weight.requires_grad:
if ctx.fp8:
# FC1 WGRAD
@@ -1026,6 +1036,8 @@ class _LayerNormMLP(torch.autograd.Function):
if ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered:
dgrad = dgrad + grad_outputs[1].view_as(dgrad)
dgamma = None
dbeta = None
if ctx.normalization == "LayerNorm":
dgrad, dgamma, dbeta = tex.layernorm_bwd(
dgrad,
@@ -1112,7 +1124,9 @@ class _LayerNormMLP(torch.autograd.Function):
dbeta,
fc1_wgrad,
None, # fc1_weight_fp8
fc1_bias_grad if ctx.use_fc1_bias else None,
# Due to the bias-gelu nvfusion available in the bf16 case, fc1_bias_grad is
# calculated along different code paths, which confuses the linter.
fc1_bias_grad if ctx.use_fc1_bias else None, # pylint: disable=used-before-assignment
None, # use_fc1_bias
fc2_wgrad,
None, # fc2_weight_fp8
@@ -1384,7 +1398,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
if with_fp8_params:
self.init_fp8_metadata(num_gemms=2)
self.reset_parameters(defer_init=(device == "meta"))
self.reset_parameters(defer_init=device == "meta")
# For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM
......
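
Note: the `fc1_bias_grad` case above is the one spot where a `None` default was not the right fix, because the value is produced along mutually exclusive code paths (the bf16 bias-gelu fusion among them) and is only read when `ctx.use_fc1_bias` holds; the PR documents the false positive and disables the check locally instead. A hedged sketch of that control-flow shape (the flags and strings are hypothetical):

```python
# Hedged sketch of a used-before-assignment false positive; names are made up.

def bias_grad_sketch(use_bias: bool, bf16_fused: bool):
    if use_bias:
        if bf16_fused:
            fc1_bias_grad = "from the fused bias-gelu backward"
        else:
            fc1_bias_grad = "from the GEMM epilogue"
    # The linter cannot prove fc1_bias_grad is bound exactly when use_bias is
    # True, so the expression is kept and the warning silenced locally:
    return fc1_bias_grad if use_bias else None  # pylint: disable=used-before-assignment
```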
@@ -85,6 +85,7 @@ class _Linear(torch.autograd.Function):
fp8_output: bool,
fsdp_group: Union[dist_group_type, None],
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
is_input_fp8 = isinstance(inp, Float8Tensor)
# Make sure input dimensions are compatible
@@ -177,6 +178,8 @@ class _Linear(torch.autograd.Function):
activation_dtype,
)
ub_algo = None
rs_out = None
if ub_overlap_rs:
ub_obj_projout = get_ub(ub_name + "_fprop")
out = ub_obj_projout.get_ubuf_output(1)
@@ -364,6 +367,7 @@ class _Linear(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
if isinstance(grad_output, Float8Tensor):
ctx.fp8_meta["scaling_bwd"].scale_inv[
tex.FP8BwdTensors.GRAD_OUTPUT1
@@ -396,6 +400,7 @@ class _Linear(torch.autograd.Function):
tp_world_size = get_distributed_world_size(ctx.tp_group)
ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag
ub_algo = None
if ctx.ub_overlap_ag:
dim_size = list(grad_output.size())
dim_size[0] = dim_size[0] * tp_world_size
@@ -507,6 +512,7 @@ class _Linear(torch.autograd.Function):
elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)
wgrad = None
if weight.requires_grad:
if ctx.fp8:
# WGRAD
@@ -873,7 +879,7 @@ class Linear(TransformerEngineBaseModule):
if with_fp8_params:
self.init_fp8_metadata()
self.reset_parameters(defer_init=(device == "meta"))
self.reset_parameters(defer_init=device == "meta")
# For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM
......
@@ -35,6 +35,7 @@ class _RMSNorm(torch.autograd.Function):
is_grad_enabled: bool,
activation_dtype: torch.dtype,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
# Make sure input dimensions are compatible
in_features = rmsnorm_weight.numel()
assert inp.is_cuda, "TransformerEngine needs CUDA."
@@ -61,6 +62,7 @@ class _RMSNorm(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
inputmat, rmsnorm_weight, rsigma = ctx.saved_tensors
grad_output = grad_output.contiguous()
d_rmsnorm_out = grad_output.view(inputmat.shape)
@@ -147,7 +149,7 @@ class RMSNorm(torch.nn.Module):
self.sequence_parallel = sequence_parallel
self.activation_dtype: Optional[torch.dtype] = None
self.reset_parameters(defer_init=(device == "meta"))
self.reset_parameters(defer_init=device == "meta")
# This many SMs are subtracted from the total SM count when calling forward
# and backward RMSNorm C APIs. These envvars can be used to prevent the LN
@@ -182,7 +184,7 @@ class RMSNorm(torch.nn.Module):
@no_torch_dynamo()
def forward(self, inp: torch.Tensor) -> torch.Tensor:
"""RMSNorm FWD"""
# pylint: disable=missing-function-docstring
# Set the activation type for AMP.
# Note: This will soon be deprecated with
......
@@ -562,12 +562,12 @@ class BasicLinear(BasicOperation):
_wait_async(x_async)
x_async = None
if with_fp8_compute:
kwargs = dict(
accumulate=accumulate_into_out,
out=y,
bias=b,
use_bias=(b is not None),
)
kwargs = {
"accumulate": accumulate_into_out,
"out": y,
"bias": b,
"use_bias": (b is not None),
}
if with_fp8_output:
if y._fp8_meta is None:
# Hackily create FP8TensorMeta if needed
@@ -584,12 +584,12 @@ class BasicLinear(BasicOperation):
fp8_meta = y._fp8_meta[fp8_meta_key]
fp8_meta_index = y._fp8_meta_index
kwargs.update(
dict(
out=y._data,
out_index=fp8_meta_index,
fp8_meta_tensor=fp8_meta,
D_dtype=y._fp8_dtype,
)
{
"out": y._data,
"out_index": fp8_meta_index,
"fp8_meta_tensor": fp8_meta,
"D_dtype": y._fp8_dtype,
}
)
fp8_gemm(
w._data,
@@ -936,10 +936,7 @@ class BasicLinear(BasicOperation):
_wait_async(dy_async)
dy_async = None
if with_fp8_compute:
kwargs = dict(
accumulate=accumulate_into_grad_input,
out=dx,
)
kwargs = {"accumulate": accumulate_into_grad_input, "out": dx}
if with_fp8_grad_input:
if dx._fp8_meta is None:
# Hackily create FP8TensorMeta if needed
@@ -958,12 +955,12 @@ class BasicLinear(BasicOperation):
fp8_meta = dx._fp8_meta[fp8_meta_key]
fp8_meta_index = dx._fp8_meta_index
kwargs.update(
dict(
out=dx._data,
out_index=fp8_meta_index,
fp8_meta_tensor=fp8_meta,
D_dtype=dx._fp8_dtype,
)
{
"out": dx._data,
"out_index": fp8_meta_index,
"fp8_meta_tensor": fp8_meta,
"D_dtype": dx._fp8_dtype,
}
)
fp8_gemm(
w.transpose_2d(),
......
@@ -38,11 +38,7 @@ class ForwardLinearBiasActivation(FusedOperation):
) -> None:
# Basic operations that comprise this fused operation
op_idxs = dict(
linear=0,
bias=None,
activation=None,
)
op_idxs = {"linear": 0, "bias": None, "activation": None}
ops = [linear]
if bias is not None:
op_idxs["bias"] = len(ops)
......
@@ -37,11 +37,7 @@ class ForwardLinearBiasAdd(FusedOperation):
) -> None:
# Basic operations that comprise this fused operation
op_idxs = dict(
linear=0,
bias=None,
add=None,
)
op_idxs = {"linear": 0, "bias": None, "add": None}
ops = [linear]
if bias is not None:
op_idxs["bias"] = len(ops)
......
@@ -91,24 +91,24 @@ class Linear(FusedOperation):
# Construct basic ops
ops = []
linear_kwargs = dict(
in_features=in_features,
out_features=out_features,
device=device,
dtype=dtype,
tensor_parallel_mode=tensor_parallel_mode,
tensor_parallel_group=tensor_parallel_group,
sequence_parallel=sequence_parallel,
rng_state_tracker_function=rng_state_tracker_function,
accumulate_into_main_grad=accumulate_into_main_grad,
)
bias_kwargs = dict(
size=out_features,
device=device,
dtype=dtype,
tensor_parallel=(tensor_parallel_mode is not None),
tensor_parallel_group=tensor_parallel_group,
)
linear_kwargs = {
"in_features": in_features,
"out_features": out_features,
"device": device,
"dtype": dtype,
"tensor_parallel_mode": tensor_parallel_mode,
"tensor_parallel_group": tensor_parallel_group,
"sequence_parallel": sequence_parallel,
"rng_state_tracker_function": rng_state_tracker_function,
"accumulate_into_main_grad": accumulate_into_main_grad,
}
bias_kwargs = {
"size": out_features,
"device": device,
"dtype": dtype,
"tensor_parallel": (tensor_parallel_mode is not None),
"tensor_parallel_group": tensor_parallel_group,
}
if tensor_parallel_mode == "row":
# Row TP: GEMM + bias + reduction
linear_kwargs["in_features"] = local_in_features
......
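
Note: the kwargs-style `dict(...)` calls in these operation modules are all rewritten as literals. pylint's `use-dict-literal` (R1735) prefers the literal form, which also skips a global name lookup and a function call while expanding with `**` exactly the same way. A hedged micro-example (names are hypothetical):

```python
# Hedged micro-example of the dict()-to-literal rewrite (use-dict-literal).

def make_linear_kwargs(in_features: int, out_features: int, device: str = "cuda") -> dict:
    # Equivalent to dict(in_features=..., out_features=..., device=...), but
    # without the dict name lookup and call.
    return {
        "in_features": in_features,
        "out_features": out_features,
        "device": device,
    }

config = make_linear_kwargs(1024, 4096)
print(config)  # {'in_features': 1024, 'out_features': 4096, 'device': 'cuda'}
```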
@@ -179,7 +179,6 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
def is_fused_op(self) -> bool:
return False
# pylint: disable=no-self-use
def num_fp8_scales(
self,
mode: str, # pylint: disable=unused-argument
@@ -225,11 +224,11 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
}
# Construct FP8 metadata for all tensor types
return dict(
input=_make_meta(self.num_fp8_scales("input"), True),
param=_make_meta(self.num_fp8_scales("param"), True),
grad_output=_make_meta(self.num_fp8_scales("grad_output"), False),
)
return {
"input": _make_meta(self.num_fp8_scales("input"), True),
"param": _make_meta(self.num_fp8_scales("param"), True),
"grad_output": _make_meta(self.num_fp8_scales("grad_output"), False),
}
@classmethod
def _maybe_update_fp8_meta(cls, fp8_meta: Optional[dict[str, Any]]) -> None:
......
@@ -46,6 +46,7 @@ class Sequential(torch.nn.Module):
self.append(module)
def add_module(self, name: str, module: Optional[torch.nn.Module]) -> None:
# pylint: disable=missing-function-docstring
self._module_groups = None
super().add_module(name, module)
......
@@ -100,13 +100,13 @@ class FusedAdam(torch.optim.Optimizer):
# If the optimizer is capturable then LR should be a tensor (on GPU)
lr = torch.tensor(lr, dtype=torch.float32) if capturable else lr
defaults = dict(
lr=lr,
bias_correction=bias_correction,
betas=betas,
eps=eps,
weight_decay=weight_decay,
)
defaults = {
"lr": lr,
"bias_correction": bias_correction,
"betas": betas,
"eps": eps,
"weight_decay": weight_decay,
}
super().__init__(params, defaults)
self.adam_w_mode = 1 if adam_w_mode else 0
self.set_grad_none = set_grad_none
@@ -135,6 +135,7 @@ class FusedAdam(torch.optim.Optimizer):
self.multi_tensor_adam_capturable_master = tex.multi_tensor_adam_capturable_master
def zero_grad(self):
# pylint: disable=missing-function-docstring
if self.set_grad_none:
for group in self.param_groups:
for p in group["params"]:
......
@@ -91,13 +91,13 @@ class FusedSGD(Optimizer):
if weight_decay < 0.0:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(
lr=lr,
momentum=momentum,
dampening=dampening,
weight_decay=weight_decay,
nesterov=nesterov,
)
defaults = {
"lr": lr,
"momentum": momentum,
"dampening": dampening,
"weight_decay": weight_decay,
"nesterov": nesterov,
}
if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super().__init__(params, defaults)
@@ -120,6 +120,7 @@ class FusedSGD(Optimizer):
group.setdefault("nesterov", False)
def zero_grad(self):
# pylint: disable=missing-function-docstring
if self.set_grad_none:
for group in self.param_groups:
for p in group["params"]:
......
@@ -32,6 +32,7 @@ class _moe_permute(torch.autograd.Function):
num_out_tokens: int,
max_token_num: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
# pylint: disable=missing-function-docstring
# Empty input check
if not inp.numel():
return inp, torch.tensor([], device=inp.device)
@@ -90,6 +91,7 @@ class _moe_permute(torch.autograd.Function):
permuted_act_grad: torch.Tensor,
_,
) -> Tuple[torch.Tensor, ...]:
# pylint: disable=missing-function-docstring
# Empty input check
if not permuted_act_grad.numel():
return permuted_act_grad, None, None, None
@@ -130,6 +132,7 @@ class _moe_unpermute(torch.autograd.Function):
row_id_map: torch.Tensor,
probs: torch.Tensor,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
# Empty input check
if not inp.numel():
ctx.probs = probs
@@ -188,6 +191,7 @@ class _moe_unpermute(torch.autograd.Function):
ctx,
unpermuted_act_grad: torch.Tensor,
) -> Tuple[torch.Tensor, None, torch.Tensor]:
# pylint: disable=missing-function-docstring
# Empty input check
if not unpermuted_act_grad.numel():
return unpermuted_act_grad, None, ctx.probs
@@ -208,6 +212,7 @@ class _moe_unpermute(torch.autograd.Function):
inp, row_id_map, probs = ctx.saved_tensors
act_grad = None
prob_grad = None
if ctx.needs_input_grad[0]:
act_grad, prob_grad = tex.moe_unpermute_bwd(
unpermuted_act_grad, inp, dtype, row_id_map, probs
......
@@ -49,7 +49,7 @@ def _make_fp8_attr_property_funcs(name: str) -> Any:
def del_func(self) -> None:
del self._fp8_attrs[name]
return dict(fget=get_func, fset=set_func, fdel=del_func)
return {"fget": get_func, "fset": set_func, "fdel": del_func}
class _FromFloat8Func(torch.autograd.Function):
@@ -61,6 +61,7 @@ class _FromFloat8Func(torch.autograd.Function):
tensor: Float8Tensor,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
return tensor.dequantize(dtype=dtype)
@staticmethod
@@ -68,6 +69,7 @@ class _FromFloat8Func(torch.autograd.Function):
_ctx: torch.autograd.function.FunctionCtx, # unused
grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
# pylint: disable=missing-function-docstring
# Assume that we want gradients in full precision
return grad, None
@@ -112,6 +114,7 @@ class _ToFloat8Func(torch.autograd.Function):
scale_inv: Optional[torch.Tensor] = None,
with_transpose_cache: bool = False,
) -> Float8Tensor:
# pylint: disable=missing-function-docstring
# Tensor attributes
dtype = tensor.dtype
@@ -167,6 +170,7 @@ class _ToFloat8Func(torch.autograd.Function):
_ctx: torch.autograd.function.FunctionCtx, # unused
grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
# pylint: disable=missing-function-docstring
# Assume that we want gradients in full precision
return grad, None, None, None, None, None, None, None
@@ -185,6 +189,7 @@ class _IdentityFunc(torch.autograd.Function):
tensor: Float8Tensor,
init_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
# Return input tensor if constructor kwargs are not provided
ctx.input_dtype = tensor.dtype
@@ -192,15 +197,15 @@ class _IdentityFunc(torch.autograd.Function):
return tensor
# Construct new tensor if constructor kwargs are provided
default_kwargs = dict(
data=tensor._data,
fp8_meta=tensor._fp8_meta,
fp8_meta_forward=tensor._fp8_meta_forward,
fp8_meta_index=tensor._fp8_meta_index,
fp8_dtype=tensor._fp8_dtype,
fp8_scale_inv=tensor._scale_inv,
dtype=tensor.dtype,
)
default_kwargs = {
"data": tensor._data,
"fp8_meta": tensor._fp8_meta,
"fp8_meta_forward": tensor._fp8_meta_forward,
"fp8_meta_index": tensor._fp8_meta_index,
"fp8_dtype": tensor._fp8_dtype,
"fp8_scale_inv": tensor._scale_inv,
"dtype": tensor.dtype,
}
for key, val in default_kwargs.items():
if key not in init_kwargs:
init_kwargs[key] = val
@@ -208,6 +213,7 @@ class _IdentityFunc(torch.autograd.Function):
@staticmethod
def backward(ctx, grad):
# pylint: disable=missing-function-docstring
return grad.to(ctx.input_dtype), None
@@ -224,6 +230,7 @@ class _ViewFunc(torch.autograd.Function):
tensor: torch.Tensor,
shape: Tuple[int] = None,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
# Return input tensor if shape is not provided
ctx.shape = tensor.shape
@@ -243,6 +250,7 @@ class _ViewFunc(torch.autograd.Function):
ctx,
grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
# pylint: disable=missing-function-docstring
if isinstance(grad, Float8Tensor):
dgrad = Float8Tensor.make_like(
@@ -266,6 +274,7 @@ class _ReshapeFunc(torch.autograd.Function):
tensor: torch.Tensor,
shape: Tuple[int] = None,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
# Return input tensor if shape is not provided
ctx.shape = tensor.shape
@@ -285,6 +294,7 @@ class _ReshapeFunc(torch.autograd.Function):
ctx,
grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
# pylint: disable=missing-function-docstring
if isinstance(grad, Float8Tensor):
dgrad = Float8Tensor.make_like(
@@ -456,14 +466,14 @@ class Float8Tensor(QuantizedTensor):
See constructor for list of keyword arguments.
"""
default_kwargs = dict(
fp8_meta=tensor._fp8_meta,
fp8_meta_forward=tensor._fp8_meta_forward,
fp8_meta_index=tensor._fp8_meta_index,
fp8_dtype=tensor._fp8_dtype,
fp8_scale_inv=tensor._scale_inv,
dtype=tensor.dtype,
)
default_kwargs = {
"fp8_meta": tensor._fp8_meta,
"fp8_meta_forward": tensor._fp8_meta_forward,
"fp8_meta_index": tensor._fp8_meta_index,
"fp8_dtype": tensor._fp8_dtype,
"fp8_scale_inv": tensor._scale_inv,
"dtype": tensor.dtype,
}
for key, val in default_kwargs.items():
if key not in kwargs:
kwargs[key] = val
@@ -697,6 +707,7 @@ class Float8Tensor(QuantizedTensor):
)
def detach(self) -> Float8Tensor:
# pylint: disable=missing-function-docstring
return Float8Tensor.make_like(
self,
data=self._data,
@@ -704,22 +715,25 @@ class Float8Tensor(QuantizedTensor):
)
def clone(self) -> Float8Tensor:
# pylint: disable=missing-function-docstring
data = self._data.detach().clone()
data_transpose = None
if self._transpose is not None:
data_transpose = self._transpose.detach().clone()
return _IdentityFunc.apply(
self,
dict(
data=data,
data_transpose=data_transpose,
),
{
"data": data,
"data_transpose": data_transpose,
},
)
def view(self, *shape: Tuple[int]) -> Float8Tensor:
# pylint: disable=missing-function-docstring
return _ViewFunc.apply(self, shape)
def reshape(self, *shape: Tuple[int]) -> Float8Tensor:
# pylint: disable=missing-function-docstring
return _ReshapeFunc.apply(self, shape)
def contiguous(
@@ -980,6 +994,7 @@ class Float8Tensor(QuantizedTensor):
requires_grad=tensor.requires_grad,
device=new_device,
)
# pylint: disable=unnecessary-dunder-call
super(Float8Tensor, type(self)).data.__set__(self, dummy_tensor)
self._data = tensor._data
self._fp8_attrs = tensor._fp8_attrs
@@ -1008,6 +1023,7 @@ class Float8Tensor(QuantizedTensor):
requires_grad=tensor.requires_grad,
device=self._data.device,
)
# pylint: disable=unnecessary-dunder-call
super(Float8Tensor, type(self)).data.__set__(self, dummy_tensor)
if self._transpose is not None:
self._transpose = torch.empty(
......
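
Note: the two new `# pylint: disable=unnecessary-dunder-call` pragmas above guard `super(Float8Tensor, type(self)).data.__set__(self, dummy_tensor)`, an intentional descriptor-protocol call that rebinds the inherited `data` property; the checker would otherwise flag the explicit `__set__`. A hedged standalone illustration (the `Holder` class is hypothetical):

```python
class Holder:
    """Minimal property owner used to demonstrate an explicit __set__ call."""

    @property
    def value(self):
        return self._v

    @value.setter
    def value(self, v):
        self._v = v

h = Holder()
# Equivalent to `h.value = 42`, but goes through the descriptor protocol
# explicitly, as Float8Tensor does for the inherited `data` property.
# pylint: disable-next=unnecessary-dunder-call
type(h).value.__set__(h, 42)
print(h.value)  # 42
```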
@@ -20,6 +20,7 @@ class _DequantizeFunc(torch.autograd.Function):
tensor: QuantizedTensor,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
return tensor.dequantize(dtype=dtype)
@staticmethod
@@ -27,6 +28,7 @@ class _DequantizeFunc(torch.autograd.Function):
_ctx: torch.autograd.function.FunctionCtx, # unused
grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
# pylint: disable=missing-function-docstring
return grad, None
@@ -38,6 +40,7 @@ class _IdentityFunc(torch.autograd.Function):
_ctx: torch.autograd.function.FunctionCtx, # unused
tensor: QuantizedTensor,
) -> QuantizedTensor:
# pylint: disable=missing-function-docstring
return tensor.detach()
@staticmethod
@@ -45,6 +48,7 @@ class _IdentityFunc(torch.autograd.Function):
_ctx: torch.autograd.function.FunctionCtx, # unused
grad: torch.Tensor,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
return grad
@@ -85,18 +89,23 @@ class QuantizedTensor(torch.Tensor):
return f"{self.__class__.__name__}(data={self.dequantize(dtype=self.dtype)})"
def float(self) -> torch.Tensor:
# pylint: disable=missing-function-docstring
return _DequantizeFunc.apply(self, torch.float32)
def bfloat16(self) -> torch.Tensor:
# pylint: disable=missing-function-docstring
return _DequantizeFunc.apply(self, torch.bfloat16)
def half(self) -> torch.Tensor:
# pylint: disable=missing-function-docstring
return _DequantizeFunc.apply(self, torch.float16)
def cpu(self) -> torch.Tensor:
# pylint: disable=missing-function-docstring
return _DequantizeFunc.apply(self).cpu()
def expand_as(self, other: torch.Tensor) -> torch.Tensor:
# pylint: disable=missing-function-docstring
if other is self:
# Note: expand_as is hackily used to create dummy autograd nodes
# and access the backward graph (see
......