Unverified Commit ac886c35 authored by Evgeny Tsykunov, committed by GitHub

[PyTorch] Fix QuantizedTensorBase -> QuantizedTensorStorage (#2226)



Fix QuantizedTensorBase -> QuantizedTensorStorage
Signed-off-by: Evgeny <etsykunov@nvidia.com>
parent b0d562d8
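
The commit is a mechanical rename: the quantized-tensor machinery exposes the class as QuantizedTensorStorage, while the attention backward passes below still referenced the old QuantizedTensorBase name in their imports and isinstance checks. A minimal sketch of the guard pattern being fixed, assuming a Transformer Engine build that exposes QuantizedTensorStorage; the helper name and the stand-alone quantizer argument are illustrative, not part of the commit:

    from transformer_engine.pytorch.tensor.quantized_tensor import (
        QuantizedTensorStorage,
    )

    def quantize_grad_if_needed(d_out, quantizer):
        # Hypothetical helper mirroring the guards in the diff below: the FP8
        # backward path expects d_out to already be quantized, so a plain
        # torch.Tensor is run through the quantizer (ctx.dO_quantizer in the
        # real code) before any further operation.
        if not isinstance(d_out, QuantizedTensorStorage):
            d_out = quantizer(d_out)
        return d_out
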
@@ -25,7 +25,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
     Float8CurrentScalingQuantizer,
 )
 from transformer_engine.pytorch.tensor.quantized_tensor import (
-    QuantizedTensorBase,
+    QuantizedTensorStorage,
     prepare_for_saving,
     restore_from_saved,
 )
@@ -1312,7 +1312,7 @@ class FusedAttnFunc(torch.autograd.Function):
         # d_out is expected to be in FP8 if is_output_fp8=True,
         # but in the case it's not, convert it to FP8 before any operation
-        if ctx.fp8 and ctx.is_output_fp8 and not isinstance(d_out, QuantizedTensorBase):
+        if ctx.fp8 and ctx.is_output_fp8 and not isinstance(d_out, QuantizedTensorStorage):
             d_out = ctx.dO_quantizer(d_out)
             if not ctx.use_FAv2_bwd:
                 d_out._data = d_out._data.contiguous()
@@ -1479,7 +1479,7 @@ class FusedAttnFunc(torch.autograd.Function):
                 ctx.dP_quantizer,
             )
         else:
-            if isinstance(d_out, QuantizedTensorBase):
+            if isinstance(d_out, QuantizedTensorStorage):
                 d_out = d_out.dequantize(dtype=ctx.nominal_dtype)
             dqkv_te_dtype = TE_DType[d_out.dtype]
             # q, k, v, out, d_out, dq, dk, dv: torch.Tensor; torch.float16 or torch.bfloat16
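
The non-FP8 branches run the check the other way: a gradient that arrives quantized is dequantized back to the nominal compute dtype before the high-precision kernels run. A sketch of that branch, with the helper name and default dtype as illustrative assumptions (the real code reads the dtype from ctx.nominal_dtype or bwd_nominal_dtype). The second file below, covering the context-parallel attention functions, receives the same two substitutions.

    import torch

    from transformer_engine.pytorch.tensor.quantized_tensor import (
        QuantizedTensorStorage,
    )

    def dequantize_grad_if_needed(d_out, nominal_dtype=torch.bfloat16):
        # Hypothetical helper mirroring the else-branches in this diff: if
        # the incoming gradient is a quantized tensor, expand it to the
        # nominal dtype so the non-FP8 backward math operates on a regular
        # torch.Tensor.
        if isinstance(d_out, QuantizedTensorStorage):
            d_out = d_out.dequantize(dtype=nominal_dtype)
        return d_out
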
@@ -21,7 +21,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
 )
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.pytorch.float8_tensor import Float8Tensor
-from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorBase
+from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage
 from transformer_engine.pytorch.jit import jit_fuser
 from transformer_engine.pytorch.constants import (
     dist_group_type,
@@ -1823,7 +1823,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
         # dout is expected to be in FP8 if is_output_fp8=True,
         # but in the case it's not, convert it to FP8 before any operation
-        if ctx.fp8 and ctx.is_output_fp8 and not isinstance(dout, QuantizedTensorBase):
+        if ctx.fp8 and ctx.is_output_fp8 and not isinstance(dout, QuantizedTensorStorage):
             dout = ctx.dO_quantizer(dout)
             if ctx.use_fused_attention:
                 dout._data = dout._data.contiguous()
@@ -1997,7 +1997,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                 dQKV_quantizer_per_step[i] = ctx.dQKV_quantizer.copy()
                 dQKV_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,))
         else:
-            if isinstance(dout, QuantizedTensorBase):
+            if isinstance(dout, QuantizedTensorStorage):
                 dout = dout.dequantize(dtype=bwd_nominal_dtype)
         dq_buffer = torch.empty_like(q)
         p2p_comm_buffers = [
@@ -3396,7 +3396,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
         if ctx.fp8:
             if ctx.use_fused_attention:
                 fused_attn_backend = FusedAttnBackend["FP8"]
-                if not isinstance(dout, QuantizedTensorBase):
+                if not isinstance(dout, QuantizedTensorStorage):
                     dout = ctx.dO_quantizer(dout)
                 dout_fp8 = dout
                 dqkv_te_dtype = dout._fp8_dtype
@@ -3409,7 +3409,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
             else:
                 assert False, "FP8 is only supported with Fused Attention!"
         else:
-            if isinstance(dout, QuantizedTensorBase):
+            if isinstance(dout, QuantizedTensorStorage):
                 dout = dout.dequantize(dtype=bwd_nominal_dtype)
             if ctx.use_fused_attention:
                 fp8_meta_kwargs = {}
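
A quick sanity check against an installed Transformer Engine; whether the old name survives as a compatibility alias depends on the version, so the sketch only asserts the new name:

    import importlib

    mod = importlib.import_module(
        "transformer_engine.pytorch.tensor.quantized_tensor"
    )
    # After this commit, the storage class is importable under its new name.
    assert hasattr(mod, "QuantizedTensorStorage")
    # The old name may or may not remain as an alias, depending on the version.
    print("QuantizedTensorBase still exported:", hasattr(mod, "QuantizedTensorBase"))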