Commit ac886c35 (unverified), authored by Evgeny Tsykunov, committed by GitHub
Browse files

[PyTorch] Fix QuantizedTensorBase -> QuantizedTensorStorage (#2226)



Fix QuantizedTensorBase -> QuantizedTensorStorage
Signed-off-by: Evgeny <etsykunov@nvidia.com>
parent b0d562d8
......@@ -25,7 +25,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.quantized_tensor import (
-QuantizedTensorBase,
+QuantizedTensorStorage,
prepare_for_saving,
restore_from_saved,
)
......@@ -1312,7 +1312,7 @@ class FusedAttnFunc(torch.autograd.Function):
# d_out is expected to be in FP8 if is_output_fp8=True,
# but in the case it's not, convert it to FP8 before any operation
-if ctx.fp8 and ctx.is_output_fp8 and not isinstance(d_out, QuantizedTensorBase):
+if ctx.fp8 and ctx.is_output_fp8 and not isinstance(d_out, QuantizedTensorStorage):
d_out = ctx.dO_quantizer(d_out)
if not ctx.use_FAv2_bwd:
d_out._data = d_out._data.contiguous()
......@@ -1479,7 +1479,7 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.dP_quantizer,
)
else:
-if isinstance(d_out, QuantizedTensorBase):
+if isinstance(d_out, QuantizedTensorStorage):
d_out = d_out.dequantize(dtype=ctx.nominal_dtype)
dqkv_te_dtype = TE_DType[d_out.dtype]
# q, k, v, out, d_out, dq, dk, dv: torch.Tensor; torch.float16 or torch.bfloat16
......
......@@ -21,7 +21,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.float8_tensor import Float8Tensor
-from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorBase
+from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage
from transformer_engine.pytorch.jit import jit_fuser
from transformer_engine.pytorch.constants import (
dist_group_type,
......@@ -1823,7 +1823,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
# dout is expected to be in FP8 if is_output_fp8=True,
# but in the case it's not, convert it to FP8 before any operation
-if ctx.fp8 and ctx.is_output_fp8 and not isinstance(dout, QuantizedTensorBase):
+if ctx.fp8 and ctx.is_output_fp8 and not isinstance(dout, QuantizedTensorStorage):
dout = ctx.dO_quantizer(dout)
if ctx.use_fused_attention:
dout._data = dout._data.contiguous()
......@@ -1997,7 +1997,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
dQKV_quantizer_per_step[i] = ctx.dQKV_quantizer.copy()
dQKV_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,))
else:
-if isinstance(dout, QuantizedTensorBase):
+if isinstance(dout, QuantizedTensorStorage):
dout = dout.dequantize(dtype=bwd_nominal_dtype)
dq_buffer = torch.empty_like(q)
p2p_comm_buffers = [
......@@ -3396,7 +3396,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
if ctx.fp8:
if ctx.use_fused_attention:
fused_attn_backend = FusedAttnBackend["FP8"]
-if not isinstance(dout, QuantizedTensorBase):
+if not isinstance(dout, QuantizedTensorStorage):
dout = ctx.dO_quantizer(dout)
dout_fp8 = dout
dqkv_te_dtype = dout._fp8_dtype
......@@ -3409,7 +3409,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
else:
assert False, "FP8 is only supported with Fused Attention!"
else:
-if isinstance(dout, QuantizedTensorBase):
+if isinstance(dout, QuantizedTensorStorage):
dout = dout.dequantize(dtype=bwd_nominal_dtype)
if ctx.use_fused_attention:
fp8_meta_kwargs = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment