Unverified Commit 30c0120b authored by Paweł Gadziński, committed by GitHub

[PyTorch] Fix small errors (#2396)



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
parent e1221735
@@ -24,10 +24,8 @@ from transformer_engine.pytorch import (
     MXFP8Quantizer,
 )
 import transformer_engine.pytorch.cpp_extensions as tex
-from transformer_engine.pytorch.module.base import (
-    fill_userbuffers_buffer_for_all_gather,
-    get_cublas_workspace_size_bytes,
-)
+from transformer_engine.pytorch.cpp_extensions.gemm import get_cublas_workspace_size_bytes
+from transformer_engine.pytorch.module.base import fill_userbuffers_buffer_for_all_gather
 
 warnings.filterwarnings("ignore", category=DeprecationWarning)
 warnings.filterwarnings("ignore", category=FutureWarning)
@@ -417,10 +415,6 @@ def _main(opts):
         std=opts.std,
     )
-
-    # Allocate cuBLAS workspace
-    workspace_size = 3 * get_cublas_workspace_size_bytes()
-    workspace = torch.empty(workspace_size, dtype=torch.uint8, device="cuda")
 
     # Gather global tensors and calculate reference result (need these first for Fp8 scales)
     if opts.bulk_overlap:
         ker_g = torch.transpose(kernel_t, 0, 1)
@@ -617,7 +611,6 @@ def _main(opts):
         return tex.general_gemm(
             kernel_t_fp8,
             gemm_inp,
-            workspace,
             out_dtype=torch.float8_e4m3fn if opts.fp8_output else torch.bfloat16,
             quantization_params=out_quantizer,
             use_split_accumulator=te.module.base._2X_ACC_FPROP,
@@ -635,7 +628,6 @@ def _main(opts):
         return tex.general_gemm(
             kernel2_t_fp8,
             gemm2_inp,
-            workspace,
             out_dtype=torch.float8_e4m3fn if opts.fp8_output else torch.bfloat16,
             quantization_params=out2_quantizer,
             use_split_accumulator=te.module.base._2X_ACC_FPROP,
@@ -648,7 +640,6 @@ def _main(opts):
         return tex.general_gemm(
             kernel_t,
             gemm_inp,
-            workspace,
             out_dtype=torch.bfloat16,
             use_split_accumulator=te.module.base._2X_ACC_FPROP,
             ub=ub_obj,
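The hunk at original line 417 stops allocating the cuBLAS workspace buffer, and the three hunks above drop the matching workspace argument from the tex.general_gemm calls. A hedged sketch of the resulting call shape follows, using a stand-in function rather than tex.general_gemm's real signature:

```python
import torch


def gemm_like(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    # Stand-in GEMM: callers no longer allocate or pass a workspace buffer here.
    return (a.to(torch.float32) @ b.to(torch.float32)).to(out_dtype)


gemm_inp = torch.randn(8, 16)
kernel_t = torch.randn(16, 4)
out = gemm_like(gemm_inp, kernel_t)  # previously: gemm_like(gemm_inp, kernel_t, workspace, ...)
```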
@@ -471,6 +471,8 @@ class OffloadSynchronizer:
         """
         if self.num_of_fwds in [None, self.num_layers - 1]:
             # reset the offload synchronizer
+            for layer_id in self.layer_states:
+                self.layer_states[layer_id].release_all_memory()
             self.num_of_fwds = 0
         else:
             self.num_of_fwds += 1
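The reset branch now also releases all per-layer memory before restarting the forward counter. A self-contained, hypothetical stand-in illustrating the same reset pattern (not Transformer Engine's actual OffloadSynchronizer):

```python
class _LayerState:
    """Hypothetical stand-in for a per-layer offload state."""

    def __init__(self):
        self.buffers = [bytearray(1024)]

    def release_all_memory(self):
        self.buffers.clear()


class _Synchronizer:
    """Hypothetical stand-in for the forward-counter reset logic."""

    def __init__(self, num_layers):
        self.num_layers = num_layers
        self.layer_states = {i: _LayerState() for i in range(num_layers)}
        self.num_of_fwds = None

    def on_forward(self):
        if self.num_of_fwds in [None, self.num_layers - 1]:
            # reset: every tracked layer state frees its buffers before the counter restarts
            for layer_id in self.layer_states:
                self.layer_states[layer_id].release_all_memory()
            self.num_of_fwds = 0
        else:
            self.num_of_fwds += 1


sync = _Synchronizer(num_layers=2)
for _ in range(4):
    sync.on_forward()
assert all(not state.buffers for state in sync.layer_states.values())
```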
@@ -948,7 +948,13 @@ def _all_gather_fp8(
     if isinstance(inp, Float8Tensor):
         dtype = inp.dtype
         device = inp.device
+        # Temporarily ensure rowwise usage for output tensor creation
+        # since we're gathering rowwise data, not the transpose
+        init_rowwise_usage = quantizer.rowwise_usage
+        init_columnwise_usage = quantizer.columnwise_usage
+        quantizer.set_usage(rowwise=True, columnwise=init_columnwise_usage)
         out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
+        quantizer.set_usage(rowwise=init_rowwise_usage, columnwise=init_columnwise_usage)
     elif isinstance(inp, Float8Tensor):
         out = inp.make_like(inp, shape=out_shape)
         out._data = torch.empty(
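The fix saves the quantizer's usage flags, forces rowwise usage just for the output allocation, then restores the caller-visible flags. A minimal, hypothetical stand-in showing that save/force/restore pattern (not Transformer Engine's Quantizer API):

```python
class _Quantizer:
    """Hypothetical stand-in for a quantizer with rowwise/columnwise usage flags."""

    def __init__(self, rowwise, columnwise):
        self.rowwise_usage = rowwise
        self.columnwise_usage = columnwise

    def set_usage(self, rowwise, columnwise):
        self.rowwise_usage, self.columnwise_usage = rowwise, columnwise

    def make_empty(self, shape):
        # The gathered output holds rowwise data, so a rowwise buffer must be requested.
        assert self.rowwise_usage, "rowwise buffer required for gathered data"
        return [[0.0] * shape[1] for _ in range(shape[0])]


quantizer = _Quantizer(rowwise=False, columnwise=True)

# Save the caller's flags, force rowwise for the allocation, then restore them.
init_rowwise_usage = quantizer.rowwise_usage
init_columnwise_usage = quantizer.columnwise_usage
quantizer.set_usage(rowwise=True, columnwise=init_columnwise_usage)
out = quantizer.make_empty((4, 8))
quantizer.set_usage(rowwise=init_rowwise_usage, columnwise=init_columnwise_usage)

assert quantizer.rowwise_usage is False and quantizer.columnwise_usage is True
```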
@@ -134,7 +134,9 @@ class MXFP8Quantizer(Quantizer):
         columnwise_data = None
         columnwise_scale_inv = None
         if self.columnwise_usage:
-            columnwise_data = torch.empty_like(data, pin_memory=pin_memory)
+            columnwise_data = torch.empty(
+                shape, dtype=torch.uint8, device=device, pin_memory=pin_memory
+            )
             columnwise_scale_inv = torch.empty(
                 round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4),
                 round_up_to_nearest_multiple(shape[-1], 128),
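The columnwise buffer is now sized from the requested shape directly instead of cloning the rowwise data tensor's metadata with torch.empty_like. A minimal sketch of the corrected allocation, with the block size, scale-inverse dtype, and rounding helper assumed to mirror the surrounding file:

```python
import math

import torch

MXFP8_BLOCK_SCALING_SIZE = 32  # assumed block size, mirroring the constant used in the file


def round_up_to_nearest_multiple(value, multiple):
    # Mirrors the helper referenced in the hunk above.
    return ((value + multiple - 1) // multiple) * multiple


shape, device, pin_memory = (128, 256), torch.device("cpu"), False

# Fixed allocation: size the columnwise buffer from the requested shape directly,
# instead of torch.empty_like on the rowwise data tensor.
columnwise_data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory)
columnwise_scale_inv = torch.empty(
    round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4),
    round_up_to_nearest_multiple(shape[-1], 128),
    dtype=torch.uint8,  # assumed E8M0 scale storage; the dtype is not shown in this hunk
    device=device,
)
print(columnwise_data.shape, columnwise_scale_inv.shape)
```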