Unverified commit 30c0120b, authored by Paweł Gadziński and committed by GitHub
Browse files

[PyTorch] Fix small errors (#2396)



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
parent e1221735
...@@ -24,10 +24,8 @@ from transformer_engine.pytorch import ( ...@@ -24,10 +24,8 @@ from transformer_engine.pytorch import (
MXFP8Quantizer, MXFP8Quantizer,
) )
import transformer_engine.pytorch.cpp_extensions as tex import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.module.base import ( from transformer_engine.pytorch.cpp_extensions.gemm import get_cublas_workspace_size_bytes
fill_userbuffers_buffer_for_all_gather, from transformer_engine.pytorch.module.base import fill_userbuffers_buffer_for_all_gather
get_cublas_workspace_size_bytes,
)
warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=FutureWarning)
...@@ -417,10 +415,6 @@ def _main(opts): ...@@ -417,10 +415,6 @@ def _main(opts):
std=opts.std, std=opts.std,
) )
# Allocate cuBLAS workspace
workspace_size = 3 * get_cublas_workspace_size_bytes()
workspace = torch.empty(workspace_size, dtype=torch.uint8, device="cuda")
# Gather global tensors and calculate reference result (need these first for Fp8 scales) # Gather global tensors and calculate reference result (need these first for Fp8 scales)
if opts.bulk_overlap: if opts.bulk_overlap:
ker_g = torch.transpose(kernel_t, 0, 1) ker_g = torch.transpose(kernel_t, 0, 1)
...@@ -617,7 +611,6 @@ def _main(opts): ...@@ -617,7 +611,6 @@ def _main(opts):
return tex.general_gemm( return tex.general_gemm(
kernel_t_fp8, kernel_t_fp8,
gemm_inp, gemm_inp,
workspace,
out_dtype=torch.float8_e4m3fn if opts.fp8_output else torch.bfloat16, out_dtype=torch.float8_e4m3fn if opts.fp8_output else torch.bfloat16,
quantization_params=out_quantizer, quantization_params=out_quantizer,
use_split_accumulator=te.module.base._2X_ACC_FPROP, use_split_accumulator=te.module.base._2X_ACC_FPROP,
...@@ -635,7 +628,6 @@ def _main(opts): ...@@ -635,7 +628,6 @@ def _main(opts):
return tex.general_gemm( return tex.general_gemm(
kernel2_t_fp8, kernel2_t_fp8,
gemm2_inp, gemm2_inp,
workspace,
out_dtype=torch.float8_e4m3fn if opts.fp8_output else torch.bfloat16, out_dtype=torch.float8_e4m3fn if opts.fp8_output else torch.bfloat16,
quantization_params=out2_quantizer, quantization_params=out2_quantizer,
use_split_accumulator=te.module.base._2X_ACC_FPROP, use_split_accumulator=te.module.base._2X_ACC_FPROP,
...@@ -648,7 +640,6 @@ def _main(opts): ...@@ -648,7 +640,6 @@ def _main(opts):
return tex.general_gemm( return tex.general_gemm(
kernel_t, kernel_t,
gemm_inp, gemm_inp,
workspace,
out_dtype=torch.bfloat16, out_dtype=torch.bfloat16,
use_split_accumulator=te.module.base._2X_ACC_FPROP, use_split_accumulator=te.module.base._2X_ACC_FPROP,
ub=ub_obj, ub=ub_obj,
......
...@@ -471,6 +471,8 @@ class OffloadSynchronizer: ...@@ -471,6 +471,8 @@ class OffloadSynchronizer:
""" """
if self.num_of_fwds in [None, self.num_layers - 1]: if self.num_of_fwds in [None, self.num_layers - 1]:
# reset the offload synchronizer # reset the offload synchronizer
for layer_id in self.layer_states:
self.layer_states[layer_id].release_all_memory()
self.num_of_fwds = 0 self.num_of_fwds = 0
else: else:
self.num_of_fwds += 1 self.num_of_fwds += 1
......
...@@ -948,7 +948,13 @@ def _all_gather_fp8( ...@@ -948,7 +948,13 @@ def _all_gather_fp8(
if isinstance(inp, Float8Tensor): if isinstance(inp, Float8Tensor):
dtype = inp.dtype dtype = inp.dtype
device = inp.device device = inp.device
# Temporarily ensure rowwise usage for output tensor creation
# since we're gathering rowwise data, not the transpose
init_rowwise_usage = quantizer.rowwise_usage
init_columnwise_usage = quantizer.columnwise_usage
quantizer.set_usage(rowwise=True, columnwise=init_columnwise_usage)
out = quantizer.make_empty(out_shape, dtype=dtype, device=device) out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
quantizer.set_usage(rowwise=init_rowwise_usage, columnwise=init_columnwise_usage)
elif isinstance(inp, Float8Tensor): elif isinstance(inp, Float8Tensor):
out = inp.make_like(inp, shape=out_shape) out = inp.make_like(inp, shape=out_shape)
out._data = torch.empty( out._data = torch.empty(
......
...@@ -134,7 +134,9 @@ class MXFP8Quantizer(Quantizer): ...@@ -134,7 +134,9 @@ class MXFP8Quantizer(Quantizer):
columnwise_data = None columnwise_data = None
columnwise_scale_inv = None columnwise_scale_inv = None
if self.columnwise_usage: if self.columnwise_usage:
columnwise_data = torch.empty_like(data, pin_memory=pin_memory) columnwise_data = torch.empty(
shape, dtype=torch.uint8, device=device, pin_memory=pin_memory
)
columnwise_scale_inv = torch.empty( columnwise_scale_inv = torch.empty(
round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4),
round_up_to_nearest_multiple(shape[-1], 128), round_up_to_nearest_multiple(shape[-1], 128),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment