Commit 546bb548 authored by yuguo

[DCU] fix bugs

parent 5b6190b2
@@ -21,7 +21,7 @@ namespace transformer_engine {
 using CompType = double;
 template <typename DataType, typename IndexType>
-__global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
+__launch_bounds__(1024) __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
                                                   const IndexType* tokens_per_expert,
                                                   int total_num_tokens, int num_experts,
                                                   int num_rows, int num_cols, int topk, float coeff,
......
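For context, __launch_bounds__(1024) tells the compiler this kernel will never be launched with more than 1024 threads per block, so it can cap per-thread register usage accordingly; on DCU/ROCm targets a 1024-thread launch of an unannotated, register-heavy kernel can otherwise fail or lose occupancy, which is presumably the bug this hunk addresses.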
@@ -38,7 +38,7 @@ from transformer_engine.pytorch.triton.per_token_group_quant import (per_token_q
 from transformer_engine.pytorch.utils import get_device_compute_capability
 from transformer_engine.pytorch.fp8 import int8_simulation_fp8, int8_simulation_fp8_tensorwise
-tensorwise_int8_check = bool(int(os.getenv("NVTE_INT8_SIM_FP8_TENSORWISE_CHECK", "0")))
+int8_simulation_fp8_tensorwise_batched = bool(int(os.getenv("NVTE_INT8_SIM_FP8_TENSORWISE_BATCHED", "0")))
 __all__ = [
     "general_gemm",
     "general_grouped_gemm",
@@ -489,7 +489,33 @@ def general_grouped_gemm(
         raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")
     if int8_simulation_fp8 and (isinstance(A[0], Float8TensorBase) or isinstance(B[0], Float8TensorBase)) and int8_simulation_fp8_tensorwise:
-        assert len(set(m_splits)) == 1, "Int8 simulation groupgemm just surpport token pad as same as batchgemm for now."
+        assert not gelu, "GELU not supported with int8 simulation groupgemm."
+        assert TE_DType_To_Torch[out_dtype] is torch.bfloat16 or TE_DType_To_Torch[out_dtype] is torch.float32, "Out_dtype must be bfloat16 or float32 for int8 simulation"
+        bias = tex.te_general_grouped_gemm(
+            A,
+            transa,
+            B,
+            transb,
+            out,
+            out_dtype,
+            m_splits,
+            grad_bias if grad else bias,
+            bias_dtype,
+            single_output,
+            gelu_input,  # this is pre_gelu_out
+            grad,  # grad
+            workspaces,
+            workspaces[0].shape[0],
+            accumulate,
+            use_split_accumulator,
+            sm_count - int(os.getenv("NVTE_EXT_MARGIN_SM", str(sm_count))),
+        )
+        return out, bias, gelu_input
+    if int8_simulation_fp8 and (isinstance(A[0], Float8TensorBase) or isinstance(B[0], Float8TensorBase)) and int8_simulation_fp8_tensorwise and int8_simulation_fp8_tensorwise_batched:
+        assert len(set(m_splits)) == 1, "Need token pad as same as batchgemm for NVTE_INT8_SIM_FP8_TENSORWISE_BATCHED."
         assert not gelu, "GELU not supported with int8 simulation groupgemm."
         assert not use_bias, "Bias not supported with int8 simulation groupgemm."
         assert TE_DType_To_Torch[out_dtype] is torch.bfloat16 or TE_DType_To_Torch[out_dtype] is torch.float32, "Out_dtype must be bfloat16 or float32 for int8 simulation"
......
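For reference, the len(set(m_splits)) == 1 assertion in the batched branch simply requires every group to carry the same (padded) token count, so the grouped GEMM can be laid out like a single batch GEMM. A tiny illustration with made-up split sizes:

    padded = [128, 128, 128, 128]   # every expert padded to the same token count
    ragged = [96, 128, 64, 128]     # unequal splits per expert
    assert len(set(padded)) == 1    # accepted by the batched int8-simulation path
    assert len(set(ragged)) != 1    # this layout would trip the assertion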
@@ -17,6 +17,14 @@ __all__ = ["get_cpu_offload_context"]
 CPUOffloadEnabled = False
+
+def get_cpu_offloading():
+    global CPUOffloadEnabled
+    return CPUOffloadEnabled
+
+def set_cpu_offloading(cpu_offloading):
+    global CPUOffloadEnabled
+    CPUOffloadEnabled = cpu_offloading
 def mark_activation_offload(*tensors):
     """Set the type of the offloading needed for a tensor."""
......
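The two new helpers give other modules a getter/setter for the module-level CPUOffloadEnabled flag instead of having them import the global directly. A usage sketch, assuming this hunk touches the usual cpu_offload module (the file name is not shown on this page):

    from transformer_engine.pytorch import cpu_offload  # assumed import path

    cpu_offload.set_cpu_offloading(True)       # turn offloading on for the current process
    assert cpu_offload.get_cpu_offloading()
    cpu_offload.set_cpu_offloading(False)      # restore the default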
...@@ -55,8 +55,8 @@ def assert_warmed_up(module: torch.nn.Module) -> None: ...@@ -55,8 +55,8 @@ def assert_warmed_up(module: torch.nn.Module) -> None:
" same recipe before exporting." " same recipe before exporting."
) )
from torch.utils.cpp_extension import IS_HIP_EXTENSION
if TORCH_MAJOR == 2 and TORCH_MINOR >= 4 or TORCH_MAJOR > 2: if (TORCH_MAJOR == 2 and TORCH_MINOR >= 4 or TORCH_MAJOR > 2) and not IS_HIP_EXTENSION:
# pylint: disable=unused-import # pylint: disable=unused-import
from .onnx_extensions import ( from .onnx_extensions import (
torch_onnx_gemm_inf_op, torch_onnx_gemm_inf_op,
......
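The added guard stops the torch >= 2.4 ONNX extension ops from being imported on ROCm builds; torch.utils.cpp_extension.IS_HIP_EXTENSION is True when PyTorch was built against HIP. A standalone check of which branch applies (a sketch, not part of the patch):

    import torch
    from torch.utils.cpp_extension import IS_HIP_EXTENSION

    TORCH_MAJOR, TORCH_MINOR = (int(x) for x in torch.__version__.split(".")[:2])
    wants_onnx_ext = (TORCH_MAJOR == 2 and TORCH_MINOR >= 4 or TORCH_MAJOR > 2) and not IS_HIP_EXTENSION
    print("ONNX extension ops enabled:", wants_onnx_ext)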
@@ -204,7 +204,7 @@ class _BatchLinear(torch.autograd.Function):
         weights_fp8 = saved_tensors[3 * ctx.num_gemms : 4 * ctx.num_gemms]
         main_grads = saved_tensors[4 * ctx.num_gemms :]
         if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
-            for i in ctx.num_gemms:
+            for i in range(ctx.num_gemms):
                 w = torch.nn.Parameter(weights[i], False)
                 w.main_grad = main_grads[i]
                 weights[i] = w
......
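The fix in this hunk is the usual "iterate the range, not the int" bug: ctx.num_gemms is an integer, so the old loop raised TypeError as soon as the cpu-offloading plus fused-wgrad path was taken. A minimal reproduction:

    num_gemms = 4
    try:
        for i in num_gemms:            # old code: an int is not iterable
            pass
    except TypeError as exc:
        print(exc)                     # 'int' object is not iterable
    for i in range(num_gemms):         # fixed code: i = 0, 1, 2, 3
        pass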
@@ -50,6 +50,7 @@ from ..tensor.quantized_tensor import (
     prepare_for_saving,
     restore_from_saved,
 )
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
 __all__ = ["GroupedLinear"]
@@ -286,7 +287,7 @@ class _GroupedLinear(torch.autograd.Function):
         if ctx.use_bias:
             grad_output_mats = torch.split(grad_output_view, ctx.m_splits)
             recipe = ctx.fp8_recipe
-            if recipe.delayed() or recipe.float8_current_scaling() or recipe.mxfp8():
+            if recipe.delayed() or (recipe.float8_current_scaling() and not IS_HIP_EXTENSION) or recipe.mxfp8():
                 # Fused bias grad + quantize kernel
                 for i in range(ctx.num_gemms):
                     grad_biases[i], grad_output[i] = tex.bgrad_quantize(
......
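Net effect: on HIP (ROCm/DCU) builds the float8 current-scaling recipe no longer takes the fused tex.bgrad_quantize branch for the bias gradient and instead falls through to the non-fused handling that follows this hunk. A mirror of the patched condition, with recipe standing in for any object exposing the three predicates (a sketch, not TE's actual helper):

    from torch.utils.cpp_extension import IS_HIP_EXTENSION

    def uses_fused_bgrad_quantize(recipe) -> bool:
        # Condition as it reads after this change.
        return (
            recipe.delayed()
            or (recipe.float8_current_scaling() and not IS_HIP_EXTENSION)
            or recipe.mxfp8()
        )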