Commit ab3e5a92 authored by yuguo


Merge commit '04c730c0' of https://github.com/NVIDIA/TransformerEngine
parents a8d19fd9 04c730c0
@@ -17,8 +17,10 @@ from transformer_engine.common.recipe import Recipe
 from ..fp8 import (
     MXFP8BlockScalingRecipeState,
     DelayedScalingRecipeState,
+    Float8BlockScalingRecipeState,
     FP8GlobalStateManager,
     RecipeState,
+    fp8_autocast,
 )
 from ..tensor import Quantizer
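The new imports wire the Float8 block-scaling recipe state and fp8_autocast into the fusible-ops module. A minimal sketch, not part of this diff and assuming the recipe class is exposed as transformer_engine.common.recipe.Float8BlockScaling, of running a regular TE module under that recipe (fusible operations themselves reject it, per the NotImplementedError added below):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

linear = te.Linear(1024, 1024).cuda()
x = torch.randn(32, 1024, device="cuda")
block_recipe = recipe.Float8BlockScaling()  # assumed recipe class name
with te.fp8_autocast(enabled=True, fp8_recipe=block_recipe):
    y = linear(x)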
@@ -218,6 +220,11 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
             if num_quantizers == 0:
                 continue
+            if recipe.float8_block_scaling():
+                raise NotImplementedError(
+                    "Fusible operations do not support FP8 block scaling recipe"
+                )

             # Construct quantization recipe state
             recipe_state = RecipeState.create(
                 recipe,
@@ -259,8 +266,13 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
                 continue
             recipe_state = self._fp8_metas[mode][fp8_meta_key]
             need_to_reset_recipe_state = (
-                recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState)
-            ) or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState))
+                (recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState))
+                or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState))
+                or (
+                    recipe.float8_block_scaling()
+                    and not isinstance(recipe_state, Float8BlockScalingRecipeState)
+                )
+            )
             if need_to_reset_recipe_state:
                 self._reset_quantization_recipe_state(recipe=recipe)
                 return
@@ -508,7 +520,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
     def get_extra_state(self) -> torch.Tensor:
         """Serialize extra state

-        Contains metadata for FP8 casting.
+        Contains metadata for quantization recipe.

         """
@@ -540,21 +552,25 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
             dst.copy_(src, non_blocking=True)
             return dst

-        # Store FP8 state
+        # Store quantizer state if needed
         state = {}
         for mode in ("forward", "backward"):

-            # Get state for a given FP8 tensor
-            if self.num_quantizers(mode) == 0:
+            # Skip if op has no quantizer state
+            if self._fp8_metas is None or self._fp8_metas[mode] is None:
                 continue
-            fp8_meta = self.get_fp8_meta(mode)
+
+            # Quantizer state
+            fp8_meta = self._fp8_metas[mode]
             state[mode] = {}
+            state[mode]["recipe"] = fp8_meta["recipe"]

-            # Store tensors
-            if "scaling_fwd" in fp8_meta:
-                state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale)
-                state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history)
-            if "scaling_bwd" in fp8_meta:
-                state[mode]["scale_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale)
-                state[mode]["amax_history_bwd"] = to_cpu(fp8_meta["scaling_bwd"].amax_history)
+            # Copy tensors to CPU and store
+            if state[mode]["recipe"].delayed():
+                if mode == "forward":
+                    state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale)
+                    state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history)
+                if mode == "backward":
+                    state[mode]["scale_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale)
+                    state[mode]["amax_history_bwd"] = to_cpu(fp8_meta["scaling_bwd"].amax_history)
@@ -595,37 +611,37 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
             dst.data = torch.empty(src.size(), dtype=dst.dtype, device=dst.device)
             dst.copy_(src, non_blocking=True)

-        # Load FP8 state
+        # Load quantizer state if needed
         for mode in ("forward", "backward"):

-            # Get state for a given FP8 tensor
+            # Skip if checkpoint has no quantizer state
             if mode not in state:
                 continue
-            if self.num_quantizers(mode) == 0:
-                continue
-            fp8_meta = self.get_fp8_meta(mode)
-            if fp8_meta is None:
-                continue

-            # Load extra state
+            # Get op's quantizer state, initializing if needed
+            if self._fp8_metas is None or self._fp8_metas[mode] is None:
+                with fp8_autocast(fp8_recipe=state[mode]["recipe"]):
+                    self._reset_quantization_recipe_state()
+            fp8_meta = self._fp8_metas[mode]
+
+            # Load extra items
+            fp8_meta["recipe"] = state[mode]["recipe"]
             fp8_meta.update(state[mode]["extra_fp8_variables"])
-            if "amax_history_fwd" in state[mode]:
-                fp8_meta["recipe"].amax_history_len = state[mode]["amax_history_fwd"].size(0)
-            elif "amax_history_bwd" in state[mode]:
-                fp8_meta["recipe"].amax_history_len = state[mode]["amax_history_bwd"].size(0)
             if "global_fp8_buffer_pos_fwd_recompute" in fp8_meta:
                 del fp8_meta["global_fp8_buffer_pos_fwd_recompute"]

             # Load tensors
-            fp8_meta = self.get_fp8_meta(mode)
-            if "scaling_fwd" in fp8_meta:
-                fp8_meta_fwd = fp8_meta["scaling_fwd"]
-                copy_tensor(state[mode]["scale_fwd"], fp8_meta_fwd.scale)
-                copy_tensor(state[mode]["amax_history_fwd"], fp8_meta_fwd.amax_history)
-            if "scaling_bwd" in fp8_meta:
-                fp8_meta_bwd = fp8_meta["scaling_bwd"]
-                copy_tensor(state[mode]["scale_bwd"], fp8_meta_bwd.scale)
-                copy_tensor(state[mode]["amax_history_bwd"], fp8_meta_bwd.amax_history)
+            if state[mode]["recipe"].delayed():
+                if mode == "forward":
+                    copy_tensor(state[mode]["scale_fwd"], fp8_meta["scaling_fwd"].scale)
+                    copy_tensor(
+                        state[mode]["amax_history_fwd"], fp8_meta["scaling_fwd"].amax_history
+                    )
+                if mode == "backward":
+                    copy_tensor(state[mode]["scale_bwd"], fp8_meta["scaling_bwd"].scale)
+                    copy_tensor(
+                        state[mode]["amax_history_bwd"], fp8_meta["scaling_bwd"].amax_history
+                    )

         # Finish CPU-GPU memory transfers
         torch.cuda.synchronize()
...
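Because BasicOperation is an nn.Module, the extra state above flows through the regular checkpoint path: PyTorch calls get_extra_state() inside state_dict() and set_extra_state() inside load_state_dict(). A minimal sketch of the round trip, with `model` standing in for any module containing fusible operations:

state = model.state_dict()                           # quantizer state captured via get_extra_state()
torch.save(state, "checkpoint.pt")
model.load_state_dict(torch.load("checkpoint.pt"))   # quantizer state restored via set_extra_state()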
@@ -133,10 +133,10 @@ class FusedAdam(torch.optim.Optimizer):
         # Add constraints to dtypes of states.
         if master_weights and master_weight_dtype not in [torch.float32, torch.float16]:
             raise RuntimeError("FusedAdam only supports fp32/fp16 master weights.")
-        if exp_avg_dtype not in [torch.float32, torch.float16, torch.uint8]:
-            raise RuntimeError("FusedAdam only supports fp32/fp16/fp8 exp_avg.")
-        if exp_avg_sq_dtype not in [torch.float32, torch.float16, torch.uint8]:
-            raise RuntimeError("FusedAdam only supports fp32/fp16/fp8 exp_avg_sq.")
+        if exp_avg_dtype not in [torch.float32, torch.float16, torch.bfloat16, torch.uint8]:
+            raise RuntimeError("FusedAdam only supports fp32/fp16/bf16/fp8 exp_avg.")
+        if exp_avg_sq_dtype not in [torch.float32, torch.float16, torch.bfloat16, torch.uint8]:
+            raise RuntimeError("FusedAdam only supports fp32/fp16/bf16/fp8 exp_avg_sq.")

         # Currently, capturable mode only supports fp32 master weights and optimizer states.
         # The reason is, if the master weights or optimizer states are not in fp32 dtype,
@@ -259,6 +259,10 @@ class FusedAdam(torch.optim.Optimizer):
             scale (torch.Tensor): A FP32 tensor representing the scaling factor.
         """
         assert unscaled_state.dtype == torch.float32
+        if scaled_state.dtype == torch.bfloat16:
+            scaled_state.copy_(unscaled_state.bfloat16())
+            return
+
         dtype = self.name_to_dtype_map[state_name]
         if dtype == torch.uint8:
             assert isinstance(scaled_state, Float8Tensor)
@@ -313,8 +317,11 @@ class FusedAdam(torch.optim.Optimizer):
             else:
                 assert state[state_name].dtype == torch.float32
                 unscaled = state[state_name]
+        elif dtype == torch.bfloat16:
+            assert state[state_name].dtype == torch.bfloat16
+            unscaled = state[state_name].float()
         else:
-            raise RuntimeError(f"Dtype of {state_name} can only be fp8/fp16/fp32.")
+            raise RuntimeError(f"Dtype of {state_name} can only be fp8/fp16/bf16/fp32.")
         return unscaled
def set_scaled_state(self, param, state_name, unscaled_state): def set_scaled_state(self, param, state_name, unscaled_state):
@@ -329,6 +336,7 @@ class FusedAdam(torch.optim.Optimizer):
                 and 'master_param'.
             unscaled_state (torch.Tensor): The original high-precision (FP32) state.
         """

         store_param_remainders = (
             self.store_param_remainders
             and state_name == "master_param"
...
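With the relaxed dtype checks above, bf16 optimizer states are now accepted alongside fp16 and fp8. A minimal sketch, assuming the import path transformer_engine.pytorch.optimizers and with `model` as a placeholder module:

import torch
from transformer_engine.pytorch.optimizers import FusedAdam

optimizer = FusedAdam(
    model.parameters(),
    lr=1e-4,
    exp_avg_dtype=torch.bfloat16,     # first-moment state kept in bf16
    exp_avg_sq_dtype=torch.bfloat16,  # second-moment state kept in bf16
)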
@@ -4,14 +4,16 @@
 """MoE Permutation API"""

 import warnings
-from typing import Tuple
+from typing import Optional, Tuple
 import torch

 import transformer_engine_torch as tex
 import transformer_engine.pytorch.triton.permutation as triton_permutation
 from transformer_engine.pytorch.constants import TE_DType
-from transformer_engine.pytorch.float8_tensor import Float8Tensor
+from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
+from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
+from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor
+from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor

 __all__ = [
     "moe_permute",
@@ -46,16 +48,6 @@ class _moe_permute_index_map(torch.autograd.Function):
assert inp.size(0) == index.size(0), "Permute not possible" assert inp.size(0) == index.size(0), "Permute not possible"
# Data type check # Data type check
fp8 = isinstance(inp, Float8Tensor)
if fp8:
assert (
inp._quantizer.scale.ndim == 0
), "Only one factor scaling per tensor (Delayed Scaling) supported by moe_permute."
dtype = inp._fp8_dtype
fp8_scale_inv = inp._scale_inv
fake_dtype = inp.dtype
inp = inp._data
else:
dtype = TE_DType[inp.dtype] dtype = TE_DType[inp.dtype]
if index.dtype != torch.int32: if index.dtype != torch.int32:
warnings.warn( warnings.warn(
@@ -80,19 +72,9 @@ class _moe_permute_index_map(torch.autograd.Function):
_moe_permute_index_map.max_expanded_token_num, _moe_permute_index_map.max_expanded_token_num,
) )
if fp8:
permuted_act = Float8Tensor(
data=permuted_act,
fp8_dtype=dtype,
fp8_scale_inv=fp8_scale_inv,
shape=permuted_act.shape,
dtype=fake_dtype,
)
ctx.row_id_map = row_id_map ctx.row_id_map = row_id_map
ctx.num_tokens = index.size(0) ctx.num_tokens = index.size(0)
ctx.topK = index.size(1) ctx.topK = index.size(1)
ctx.fp8 = fp8
return permuted_act, row_id_map return permuted_act, row_id_map
@staticmethod @staticmethod
@@ -109,30 +91,12 @@ class _moe_permute_index_map(torch.autograd.Function):
if not permuted_act_grad.is_contiguous(): if not permuted_act_grad.is_contiguous():
permuted_act_grad = permuted_act_grad.contiguous() permuted_act_grad = permuted_act_grad.contiguous()
if ctx.fp8:
assert isinstance(
permuted_act_grad, Float8Tensor
), "Grad of the output must be in Float8Tensor type for FP8 moe_permute."
dtype = permuted_act_grad._fp8_dtype
fp8_scale_inv = permuted_act_grad._scale_inv
fake_dtype = permuted_act_grad.dtype
permuted_act_grad = permuted_act_grad._data
else:
dtype = TE_DType[permuted_act_grad.dtype] dtype = TE_DType[permuted_act_grad.dtype]
act_grad = None act_grad = None
if ctx.needs_input_grad[0]: if ctx.needs_input_grad[0]:
act_grad = tex.moe_permute_bwd( act_grad = tex.moe_permute_bwd(
permuted_act_grad, dtype, ctx.row_id_map, torch.empty(0), ctx.num_tokens, ctx.topK permuted_act_grad, dtype, ctx.row_id_map, torch.empty(0), ctx.num_tokens, ctx.topK
) )
if ctx.fp8:
act_grad = Float8Tensor(
data=act_grad,
fp8_dtype=dtype,
fp8_scale_inv=fp8_scale_inv * ctx.topK,
shape=act_grad.shape,
dtype=fake_dtype,
)
return act_grad, None, None, None return act_grad, None, None, None
@@ -176,13 +140,6 @@ class _moe_unpermute_index_map(torch.autograd.Function):
assert row_id_map.is_cuda, "TransformerEngine needs CUDA." assert row_id_map.is_cuda, "TransformerEngine needs CUDA."
# Data type check # Data type check
fp8 = isinstance(inp, Float8Tensor)
if fp8:
dtype = inp._fp8_dtype
fp8_scale_inv = inp._scale_inv
fake_dtype = inp.dtype
inp = inp._data
else:
dtype = TE_DType[inp.dtype] dtype = TE_DType[inp.dtype]
if row_id_map.dtype != torch.int32: if row_id_map.dtype != torch.int32:
warnings.warn( warnings.warn(
@@ -193,17 +150,7 @@ class _moe_unpermute_index_map(torch.autograd.Function):
unpermuted_output = tex.moe_unpermute_fwd(inp, dtype, row_id_map, probs, num_tokens, topK) unpermuted_output = tex.moe_unpermute_fwd(inp, dtype, row_id_map, probs, num_tokens, topK)
if fp8:
unpermuted_output = Float8Tensor(
data=unpermuted_output,
fp8_dtype=dtype,
fp8_scale_inv=fp8_scale_inv,
shape=unpermuted_output.shape,
dtype=fake_dtype,
)
ctx.save_for_backward(inp, row_id_map, probs) ctx.save_for_backward(inp, row_id_map, probs)
ctx.fp8 = fp8
return unpermuted_output return unpermuted_output
@staticmethod @staticmethod
@@ -219,17 +166,7 @@ class _moe_unpermute_index_map(torch.autograd.Function):
if not unpermuted_act_grad.is_contiguous(): if not unpermuted_act_grad.is_contiguous():
unpermuted_act_grad = unpermuted_act_grad.contiguous() unpermuted_act_grad = unpermuted_act_grad.contiguous()
if ctx.fp8:
assert isinstance(
unpermuted_act_grad, Float8Tensor
), "Grad of the output must be in Float8Tensor type for FP8 moe_unpermute."
dtype = unpermuted_act_grad._fp8_dtype
fp8_scale_inv = unpermuted_act_grad._scale_inv
fake_dtype = unpermuted_act_grad.dtype
unpermuted_act_grad = unpermuted_act_grad._data
else:
dtype = TE_DType[unpermuted_act_grad.dtype] dtype = TE_DType[unpermuted_act_grad.dtype]
inp, row_id_map, probs = ctx.saved_tensors inp, row_id_map, probs = ctx.saved_tensors
act_grad = None act_grad = None
@@ -238,14 +175,6 @@ class _moe_unpermute_index_map(torch.autograd.Function):
act_grad, prob_grad = tex.moe_unpermute_bwd( act_grad, prob_grad = tex.moe_unpermute_bwd(
unpermuted_act_grad, inp, dtype, row_id_map, probs unpermuted_act_grad, inp, dtype, row_id_map, probs
) )
if ctx.fp8:
act_grad = Float8Tensor(
data=act_grad,
fp8_dtype=dtype,
fp8_scale_inv=fp8_scale_inv,
shape=act_grad.shape,
dtype=fake_dtype,
)
if not ctx.needs_input_grad[2]: if not ctx.needs_input_grad[2]:
prob_grad = None prob_grad = None
@@ -282,22 +211,54 @@ class _moe_permute_mask_map(torch.autograd.Function):
row_id_map = triton_permutation.make_row_id_map(routing_map, num_tokens, num_experts) row_id_map = triton_permutation.make_row_id_map(routing_map, num_tokens, num_experts)
fp8 = isinstance(inp, Float8Tensor) fp8 = isinstance(inp, QuantizedTensor)
per_tensor_recipe = isinstance(inp, Float8Tensor)
blockwise_recipe = isinstance(inp, Float8BlockwiseQTensor)
mxfp8_recipe = isinstance(inp, MXFP8Tensor)
if fp8: if fp8:
fp8_dtype = inp._fp8_dtype fp8_dtype = inp._fp8_dtype
fp8_scale_inv = inp._scale_inv
fake_dtype = inp.dtype fake_dtype = inp.dtype
# blockwise scaling
if blockwise_recipe:
fp8_scale = inp._rowwise_scale_inv.T.contiguous()
scale_hidden_dim = fp8_scale.shape[1]
assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
inp = inp._rowwise_data
# mxfp8 scaling
elif mxfp8_recipe:
fp8_scale = inp._rowwise_scale_inv.contiguous()
scale_hidden_dim = fp8_scale.shape[1]
assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
inp = inp._rowwise_data
# per-tensor scaling
elif per_tensor_recipe:
# Kernel does not need scale in per-tensor scaling
fp8_scale = None
scale_hidden_dim = None
fp8_scale_inv = inp._scale_inv
inp = inp._data inp = inp._data
output, permuted_probs = triton_permutation.permute_with_mask_map( else:
raise ValueError("Unsupported FP8 recipe")
else:
fp8_scale = None
fp8_dtype = None
scale_hidden_dim = None
output, permuted_scale, permuted_probs = triton_permutation.permute_with_mask_map(
inp, inp,
row_id_map, row_id_map,
probs, probs,
fp8_scale,
num_tokens, num_tokens,
num_experts, num_experts,
num_out_tokens, num_out_tokens,
hidden_size, hidden_size,
scale_hidden_dim,
) )
if fp8: if fp8:
if per_tensor_recipe:
output = Float8Tensor( output = Float8Tensor(
data=output, data=output,
fp8_dtype=fp8_dtype, fp8_dtype=fp8_dtype,
@@ -305,6 +266,31 @@ class _moe_permute_mask_map(torch.autograd.Function):
shape=output.shape, shape=output.shape,
dtype=fake_dtype, dtype=fake_dtype,
) )
elif blockwise_recipe:
output = Float8BlockwiseQTensor(
shape=output.shape,
dtype=fake_dtype,
rowwise_data=output,
rowwise_scale_inv=permuted_scale.T.contiguous(),
columnwise_data=None,
columnwise_scale_inv=None,
fp8_dtype=fp8_dtype,
quantizer=None,
is_2D_scaled=False,
requires_grad=output.requires_grad,
)
elif mxfp8_recipe:
output = MXFP8Tensor(
shape=output.shape,
dtype=fake_dtype,
fp8_dtype=fp8_dtype,
rowwise_data=output,
rowwise_scale_inv=permuted_scale.contiguous(),
columnwise_data=None,
columnwise_scale_inv=None,
quantizer=None,
requires_grad=output.requires_grad,
)
ctx.save_for_backward(row_id_map) ctx.save_for_backward(row_id_map)
ctx.num_experts = num_experts ctx.num_experts = num_experts
@@ -327,14 +313,9 @@ class _moe_permute_mask_map(torch.autograd.Function):
probs_grad = None probs_grad = None
if ctx.needs_input_grad[0]: if ctx.needs_input_grad[0]:
(row_id_map,) = ctx.saved_tensors (row_id_map,) = ctx.saved_tensors
fp8 = isinstance(permuted_act_grad, Float8Tensor) assert not isinstance(
if fp8: permuted_act_grad, QuantizedTensor
fp8_dtype = permuted_act_grad._fp8_dtype ), "The backward of moe_permute does not support FP8."
fp8_scale_inv = permuted_act_grad._scale_inv
fake_dtype = permuted_act_grad.dtype
permuted_act_grad = permuted_act_grad._data
else:
fp8_dtype = None
act_grad, probs_grad = triton_permutation.unpermute_with_mask_map( act_grad, probs_grad = triton_permutation.unpermute_with_mask_map(
permuted_act_grad, permuted_act_grad,
row_id_map, row_id_map,
@@ -343,15 +324,6 @@ class _moe_permute_mask_map(torch.autograd.Function):
ctx.num_tokens, ctx.num_tokens,
ctx.num_experts, ctx.num_experts,
ctx.hidden_size, ctx.hidden_size,
fp8_dtype,
)
if fp8:
act_grad = Float8Tensor(
data=act_grad,
fp8_dtype=fp8_dtype,
fp8_scale_inv=fp8_scale_inv * ctx.num_experts,
shape=act_grad.shape,
dtype=fake_dtype,
) )
if not ctx.needs_input_grad[3]: if not ctx.needs_input_grad[3]:
probs_grad = None probs_grad = None
@@ -366,8 +338,8 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
         ctx,
         inp: torch.Tensor,
         row_id_map: torch.Tensor,
-        merging_probs: torch.Tensor,
-        restore_shape: torch.Size,
+        merging_probs: Optional[torch.Tensor],
+        restore_shape: Optional[torch.Size],
     ) -> torch.Tensor:
         # pylint: disable=missing-function-docstring
         if not inp.numel():
@@ -387,17 +359,9 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
assert inp.is_cuda, "TransformerEngine needs CUDA." assert inp.is_cuda, "TransformerEngine needs CUDA."
assert row_id_map.is_cuda, "TransformerEngine needs CUDA." assert row_id_map.is_cuda, "TransformerEngine needs CUDA."
fp8 = isinstance(inp, Float8Tensor) assert not isinstance(
if fp8: inp, QuantizedTensor
fp8_dtype = inp._fp8_dtype ), "The forward of moe_unpermute does not support FP8."
if not with_probs:
fp8_scale_inv = inp._scale_inv * num_experts
else:
fp8_scale_inv = inp._scale_inv
fake_dtype = inp.dtype
inp = inp._data
else:
fp8_dtype = None
unpermuted_output, _ = triton_permutation.unpermute_with_mask_map( unpermuted_output, _ = triton_permutation.unpermute_with_mask_map(
inp, inp,
row_id_map, row_id_map,
@@ -406,15 +370,6 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
num_tokens, num_tokens,
num_experts, num_experts,
hidden_size, hidden_size,
fp8_dtype=fp8_dtype,
)
if fp8:
unpermuted_output = Float8Tensor(
data=unpermuted_output,
fp8_dtype=fp8_dtype,
fp8_scale_inv=fp8_scale_inv,
shape=unpermuted_output.shape,
dtype=fake_dtype,
) )
if with_probs: if with_probs:
@@ -442,16 +397,44 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
else: else:
(row_id_map,) = ctx.saved_tensors (row_id_map,) = ctx.saved_tensors
fp8 = isinstance(unpermuted_act_grad, Float8Tensor) fp8 = isinstance(unpermuted_act_grad, QuantizedTensor)
per_tensor_recipe = isinstance(unpermuted_act_grad, Float8Tensor)
blockwise_recipe = isinstance(unpermuted_act_grad, Float8BlockwiseQTensor)
mxfp8_recipe = isinstance(unpermuted_act_grad, MXFP8Tensor)
if fp8: if fp8:
fp8_dtype = unpermuted_act_grad._fp8_dtype fp8_dtype = unpermuted_act_grad._fp8_dtype
fp8_scale_inv = unpermuted_act_grad._scale_inv
fake_dtype = unpermuted_act_grad.dtype fake_dtype = unpermuted_act_grad.dtype
# per-tensor scaling
if per_tensor_recipe:
# Kernel does not need scale in per-tensor scaling
fp8_scale = None
scale_hidden_dim = None
fp8_scale_inv = unpermuted_act_grad._scale_inv
unpermuted_act_grad = unpermuted_act_grad._data unpermuted_act_grad = unpermuted_act_grad._data
# blockwise scaling
elif blockwise_recipe:
fp8_scale = unpermuted_act_grad._rowwise_scale_inv.T.contiguous()
unpermuted_act_grad = unpermuted_act_grad._rowwise_data
scale_hidden_dim = fp8_scale.shape[1]
assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
# mxfp8 scaling
elif mxfp8_recipe:
fp8_scale = unpermuted_act_grad._rowwise_scale_inv.contiguous()
unpermuted_act_grad = unpermuted_act_grad._rowwise_data
scale_hidden_dim = fp8_scale.shape[1]
assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
else:
raise ValueError("Unsupported FP8 recipe")
else: else:
scale_hidden_dim = None
fp8_dtype = None fp8_dtype = None
fp8_scale = None
if ctx.with_probs: if ctx.with_probs:
assert (
not fp8
), "The backward of moe_unpermute with merging probs does not support FP8."
act_grad, probs_grad = ( act_grad, probs_grad = (
triton_permutation.unpermute_with_mask_map_bwd_with_merging_probs( triton_permutation.unpermute_with_mask_map_bwd_with_merging_probs(
unpermuted_act_grad, unpermuted_act_grad,
@@ -462,21 +445,23 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
ctx.num_experts, ctx.num_experts,
ctx.num_permuted_tokens, ctx.num_permuted_tokens,
ctx.hidden_size, ctx.hidden_size,
fp8_dtype,
) )
) )
else: else:
act_grad, _ = triton_permutation.permute_with_mask_map( act_grad, permuted_scale, _ = triton_permutation.permute_with_mask_map(
unpermuted_act_grad, unpermuted_act_grad,
row_id_map, row_id_map,
None, None,
fp8_scale,
ctx.num_tokens, ctx.num_tokens,
ctx.num_experts, ctx.num_experts,
ctx.num_permuted_tokens, ctx.num_permuted_tokens,
ctx.hidden_size, ctx.hidden_size,
scale_hidden_dim,
) )
if fp8: if fp8:
if per_tensor_recipe:
act_grad = Float8Tensor( act_grad = Float8Tensor(
data=act_grad, data=act_grad,
fp8_dtype=fp8_dtype, fp8_dtype=fp8_dtype,
@@ -484,6 +469,31 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
shape=act_grad.shape, shape=act_grad.shape,
dtype=fake_dtype, dtype=fake_dtype,
) )
elif blockwise_recipe:
act_grad = Float8BlockwiseQTensor(
shape=act_grad.shape,
dtype=fake_dtype,
rowwise_data=act_grad,
rowwise_scale_inv=permuted_scale.T.contiguous(),
columnwise_data=None,
columnwise_scale_inv=None,
fp8_dtype=fp8_dtype,
quantizer=None,
is_2D_scaled=False,
requires_grad=act_grad.requires_grad,
)
elif mxfp8_recipe:
act_grad = MXFP8Tensor(
shape=act_grad.shape,
dtype=fake_dtype,
fp8_dtype=fp8_dtype,
rowwise_data=act_grad,
rowwise_scale_inv=permuted_scale.contiguous(),
columnwise_data=None,
columnwise_scale_inv=None,
quantizer=None,
requires_grad=act_grad.requires_grad,
)
if not ctx.needs_input_grad[2]: if not ctx.needs_input_grad[2]:
probs_grad = None probs_grad = None
@@ -568,10 +578,10 @@ def moe_permute_with_probs(
 def moe_unpermute(
     inp: torch.Tensor,
     row_id_map: torch.Tensor,
-    merging_probs: torch.Tensor = None,
-    restore_shape: torch.Tensor = None,
+    merging_probs: Optional[torch.Tensor] = None,
+    restore_shape: Optional[torch.Size] = None,
     map_type: str = "mask",
-    probs: torch.Tensor = None,
+    probs: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Unpermute a tensor with permuted tokens, and optionally merge the tokens with their
@@ -588,7 +598,7 @@ def moe_unpermute(
         The tensor of probabilities corresponding to the permuted tokens. If provided,
         the unpermuted tokens will be merged with their respective probabilities.
         By default, set to an empty tensor, which means that the tokens are directly merged by accumulation.
-    restore_shape: torch.Tensor
+    restore_shape: torch.Size, default = None
         The output shape after the unpermute operation.
     map_type: str, default = 'mask'
         Type of the routing map tensor. Should be the same as the value passed to moe_permute.
...
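A minimal round-trip sketch for the mask-map path; moe_unpermute's signature comes from this file, while moe_permute's keyword names and the module path are assumptions and may differ from the actual API:

import torch
from transformer_engine.pytorch.permutation import moe_permute, moe_unpermute  # module path assumed

num_tokens, hidden, num_experts = 8, 16, 4
x = torch.randn(num_tokens, hidden, device="cuda")
routing_map = torch.zeros(num_tokens, num_experts, dtype=torch.bool, device="cuda")
routing_map[torch.arange(num_tokens), torch.randint(0, num_experts, (num_tokens,))] = True

permuted, row_id_map = moe_permute(x, routing_map, map_type="mask")  # assumed call signature
restored = moe_unpermute(permuted, row_id_map, restore_shape=x.shape, map_type="mask")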
...@@ -42,3 +42,27 @@ def _make_module_cast_func(dtype): ...@@ -42,3 +42,27 @@ def _make_module_cast_func(dtype):
torch.nn.Module.float = _make_module_cast_func(torch.float32) torch.nn.Module.float = _make_module_cast_func(torch.float32)
torch.nn.Module.half = _make_module_cast_func(torch.float16) torch.nn.Module.half = _make_module_cast_func(torch.float16)
torch.nn.Module.bfloat16 = _make_module_cast_func(torch.bfloat16) torch.nn.Module.bfloat16 = _make_module_cast_func(torch.bfloat16)
def get_all_tensor_types():
"""
Get all tensor-like types that can be used in TE.
"""
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8TensorBase
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8TensorBase
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
Float8BlockwiseQTensor,
Float8BlockwiseQTensorBase,
)
all_tensor_types = [
torch.Tensor,
torch.nn.Parameter,
Float8Tensor,
Float8TensorBase,
MXFP8Tensor,
MXFP8TensorBase,
Float8BlockwiseQTensor,
Float8BlockwiseQTensorBase,
]
return all_tensor_types
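One plausible use for this helper (an assumption, not shown in this diff) is registering TE's tensor classes as safe globals so that checkpoints containing them can be loaded with torch.load(..., weights_only=True) on PyTorch 2.4+:

# hypothetical call site in the same module; add_safe_globals is part of torch.serialization
torch.serialization.add_safe_globals(get_all_tensor_types())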
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Mixin class holding data specific for Float8BlockwiseQTensor"""
from __future__ import annotations
import math
from typing import Optional, Dict, Any, Tuple
import torch
from transformer_engine_torch import DType as TE_DType
from ...constants import TE_DType_To_Torch
from ..quantized_tensor import Quantizer
class Float8BlockwiseQTensorBase:
"""Mixin class that holds data attributes of Float8BlockwiseQTensor.
Float8BlockwiseQTensor inherits from the PyTorch tensor class and this
mixin class. If this class is instantiated directly, it has the same
data, lower CPU overhead, and less functionality. It should only
be instantiated directly for performance-critical internal usage.
"""
_rowwise_data: Optional[torch.Tensor]
_columnwise_data: Optional[torch.Tensor]
_quantizer: Quantizer
_fp8_dtype: TE_DType
_rowwise_scale_inv: Optional[torch.Tensor]
_columnwise_scale_inv: Optional[torch.Tensor]
_is_2D_scaled: bool
def __new__(
cls,
*args,
rowwise_data: Optional[torch.Tensor],
rowwise_scale_inv: Optional[torch.Tensor],
columnwise_data: Optional[torch.Tensor],
columnwise_scale_inv: Optional[torch.Tensor],
fp8_dtype: TE_DType,
quantizer: Quantizer,
is_2D_scaled: bool,
**kwargs,
):
instance = super().__new__(cls, *args, **kwargs)
instance._rowwise_data = rowwise_data
instance._columnwise_data = columnwise_data
instance._quantizer = quantizer
instance._fp8_dtype = fp8_dtype
instance._rowwise_scale_inv = rowwise_scale_inv
instance._columnwise_scale_inv = columnwise_scale_inv
instance._is_2D_scaled = is_2D_scaled
return instance
def get_metadata(self) -> Dict[str, Any]:
"""Get this tensor's metadata."""
return {
"rowwise_data": self._rowwise_data,
"rowwise_scale_inv": self._rowwise_scale_inv,
"columnwise_data": self._columnwise_data,
"columnwise_scale_inv": self._columnwise_scale_inv,
"fp8_dtype": self._fp8_dtype,
"quantizer": self._quantizer,
"is_2D_scaled": self._is_2D_scaled,
}
def prepare_for_saving(
self,
) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]:
"""
Prepare the tensor base for saving for backward
This currently does not clear the tensors, to stay compatible with pipeline-parallel
configurations that clear the weight cache between micro-batches. If the rowwise
data is not required for backward, this is a possible memory
pessimization, but is consistent with the other quantized tensor
classes.
"""
tensors = [self._rowwise_data, self._columnwise_data]
return tensors, self
def restore_from_saved(
self, tensors: list[Optional[torch.Tensor]]
) -> list[Optional[torch.Tensor]]:
"""Restore the tensor base data from the saved tensors list."""
self._rowwise_data = tensors[0]
self._columnwise_data = tensors[1]
return tensors[2:]
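    # A sketch of how an autograd.Function is expected to use the pair above
    # (derived from prepare_for_saving/restore_from_saved; `qtensor` and `saved`
    # are placeholders, not names from this file):
    #
    #   tensors, holder = qtensor.prepare_for_saving()    # [rowwise, columnwise] plus the holder object
    #   ctx.save_for_backward(*tensors)                   # forward pass
    #   ...
    #   leftover = holder.restore_from_saved(list(saved)) # backward pass: repopulate holder, return the rest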
def get_data_tensors(self):
"""Get this Tensor's data."""
return self._rowwise_data, self._columnwise_data
def _transpose_dq_columnwise_output(self, columnwise_dq: torch.Tensor) -> torch.Tensor:
"""Takes dequantized columnwise data and permutes to a rowwise shape"""
if columnwise_dq.dim() < 2:
return columnwise_dq
permute_dims = list(range(1, columnwise_dq.dim()))
permute_dims.append(0)
return torch.permute(columnwise_dq, tuple(permute_dims)).contiguous()
def _dequantize_vectorwise(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
block_len = 128
q_M, q_K = 1, 1
if self._rowwise_data is not None:
q = self._rowwise_data
scale_inv = self._rowwise_scale_inv
transpose_output = False
if len(q.shape) >= 1:
q_K = q.shape[-1]
for i in range(len(q.shape) - 1):
q_M *= q.shape[i]
else:
assert self._columnwise_data is not None, "No data to dequantize"
q = self._columnwise_data
scale_inv = self._columnwise_scale_inv
transpose_output = True
if len(q.shape) >= 1:
q_M = q.shape[0]
for i in range(1, len(q.shape)):
q_K *= q.shape[i]
orig_shape = q.shape
q = q.reshape(q_M, q_K)
k_tiles, scale_m = scale_inv.shape
if q_K % block_len != 0:
k_pad_amount = (block_len - (q_K % block_len)) % block_len
q = torch.nn.functional.pad(
q, (0, k_pad_amount, 0, 0), mode="constant", value=0
).contiguous()
_, padded_K = q.shape
q_tiled = q.reshape(q_M, k_tiles, block_len)
if scale_m > q_M:
# scale_m is 4 element aligned.
scale_inv = scale_inv[:, :q_M].contiguous()
dq_scale = scale_inv.transpose(-2, -1).contiguous().reshape(q_M, k_tiles, 1)
torch_q_dtype = TE_DType_To_Torch[self._fp8_dtype]
result = q_tiled.view(torch_q_dtype).to(torch.float32) * dq_scale
if padded_K != q_K:
result = result.reshape(q_M, padded_K)[:, :q_K]
result = result.to(dtype)
if len(orig_shape) == 0:
result = result.reshape([])
else:
result = result.reshape(*orig_shape).contiguous()
if transpose_output:
return self._transpose_dq_columnwise_output(result)
return result
def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
"""
Construct plain PyTorch tensor from Float8BlockwiseQTensor
"""
block_len = 128
if not self._is_2D_scaled:
return self._dequantize_vectorwise(dtype=dtype)
def format_scale_as_logical_shape(q_K, scales, block_len):
# The GEMM for 2D blocks required padding in the scales.
derived_scale_k_shape = math.ceil(q_K / block_len)
_, scale_K = scales.shape
if derived_scale_k_shape == scale_K:
return scales
return scales[:, :derived_scale_k_shape].contiguous()
q_M, q_K = 1, 1
if self._rowwise_data is not None:
q = self._rowwise_data
scale_inv = self._rowwise_scale_inv
transpose_output = False
if len(q.shape) >= 1:
q_K = q.shape[-1]
for i in range(len(q.shape) - 1):
q_M *= q.shape[i]
else:
assert self._columnwise_data is not None, "No data to dequantize"
q = self._columnwise_data
scale_inv = self._columnwise_scale_inv
transpose_output = True
if len(q.shape) >= 1:
q_M = q.shape[0]
for i in range(1, len(q.shape)):
q_K *= q.shape[i]
orig_shape = q.shape
q = q.reshape(q_M, q_K)
formatted_scales = format_scale_as_logical_shape(q_K, scale_inv, block_len)
assert len(formatted_scales.shape) == 2
m_tiles, k_tiles = formatted_scales.shape
unpadded_m, unpadded_k = q_M, q_K
m_block_len = block_len
k_block_len = block_len
if q_M % m_block_len != 0 or q_K % k_block_len != 0:
m_pad_amount = (m_block_len - (q_M % m_block_len)) % m_block_len
k_pad_amount = (k_block_len - (q_K % k_block_len)) % k_block_len
q = torch.nn.functional.pad(
q, (0, k_pad_amount, 0, m_pad_amount), mode="constant", value=0
).contiguous()
padded_M, padded_K = q.shape
q_tiled = q.reshape(m_tiles, m_block_len, k_tiles, k_block_len)
torch_q_dtype = TE_DType_To_Torch[self._fp8_dtype]
result = q_tiled.view(torch_q_dtype).to(torch.float32) * formatted_scales.view(
m_tiles, 1, k_tiles, 1
)
result = result.view(padded_M, padded_K).to(dtype)
if padded_M != unpadded_m or padded_K != unpadded_k:
result = result[:unpadded_m, :unpadded_k]
if len(orig_shape) == 0:
result = result.reshape([])
else:
result = result.reshape(*orig_shape).contiguous()
if transpose_output:
return self._transpose_dq_columnwise_output(result)
return result
def size(self, *args, **kwargs):
# pylint: disable=missing-function-docstring
if self._rowwise_data is not None:
return self._rowwise_data.size(*args, **kwargs)
dims = list(self._columnwise_data.size(*args, **kwargs))
reordered = []
for i in range(1, len(dims)):
reordered.append(dims[i])
reordered.append(dims[0])
return torch.Size(reordered)
def __repr__(self):
if self._rowwise_data is not None:
data = self.dequantize()
descriptor = "rowwise"
else:
data = self.dequantize()
descriptor = "columnwise"
return (
"Float8BlockwiseQTensorBase("
f"fp8_dtype={self._fp8_dtype}, "
f"{descriptor}_scaled_data={data}"
)
@@ -27,12 +27,14 @@ class _FromFloat8Func(torch.autograd.Function):
         dtype: torch.dtype,
     ) -> torch.Tensor:
         # pylint: disable=missing-function-docstring
-        dtype = torch_to_transformer_engine_dtype[dtype]
+        te_dtype = torch_to_transformer_engine_dtype[dtype]

         # Make sure FP8 data is in expected format
         if tensor._data is not None:
+            if tensor._data.numel() == 0:
+                return torch.empty_like(tensor._data, dtype=dtype)
             # Cast from FP8
-            return tex.dequantize(tensor, dtype)
+            return tex.dequantize(tensor, te_dtype)
         raise NotImplementedError("Casting back from the transpose not implemented yet!")
@@ -134,3 +136,11 @@ class Float8TensorBase:
             f"data={self.dequantize()}"
             ")"
         )
+
+    def _create_transpose(self):
+        """Update FP8 transpose cache"""
+        data = self._data
+        if not data.is_contiguous():
+            data = data.contiguous()
+        self._transpose = tex.fp8_transpose(data, self._fp8_dtype, out=self._transpose)
+        self._transpose_invalid = False
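The cached transpose added here is presumably rebuilt lazily by callers; a hedged sketch of that consumer pattern (hypothetical, not from this diff):

# refresh the columnwise view only when the cache is missing or stale
if tensor._transpose is None or tensor._transpose_invalid:
    tensor._create_transpose()
columnwise_data = tensor._transpose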
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tensor class with FP8 data quantized with NxN tiles"""
from __future__ import annotations
from typing import Optional, Tuple, Iterable
import math
import torch
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
from ..utils import devices_match, round_up_to_nearest_multiple
aten = torch.ops.aten
class Float8BlockQuantizer(Quantizer):
"""Builder class for tensors quantized with current scaling using
NxN quantization tilings to choose scale.
This class is typically used to convert a high-precision tensor
(e.g. in FP32 or BF16) into a quantized tensor (e.g. in FP8).
"""
dtype: TE_DType
block_len: int
amax_epsilon: float
force_pow_2_scales: bool
block_scaling_dim: int
def __init__(
self,
fp8_dtype: TE_DType,
*,
rowwise: bool,
columnwise: bool,
amax_epsilon: float = 0.0,
force_pow_2_scales: bool = True,
block_scaling_dim: int = 2,
) -> None:
super().__init__(rowwise=rowwise, columnwise=columnwise)
self.dtype = fp8_dtype
self.block_len = 128
self.force_pow_2_scales = force_pow_2_scales
self.amax_epsilon = amax_epsilon
self.block_scaling_dim = block_scaling_dim
def update_quantized(
self,
src: torch.Tensor,
dst: QuantizedTensor,
*,
noop_flag: Optional[torch.Tensor] = None,
) -> QuantizedTensor:
"""Update the quantized tensor with data from the source tensor.
This method quantizes the input tensor and stores the result in the destination tensor.
Parameters
----------
src : torch.Tensor
Source tensor containing the data to be quantized
dst : QuantizedTensor
Destination tensor where the quantized data will be stored
noop_flag : Optional[torch.Tensor]
Optional flag tensor indicating whether to skip the quantization operation
Returns
-------
QuantizedTensor
The destination tensor containing the quantized data
Raises
------
AssertionError
If the destination tensor is not a Float8BlockwiseQTensor
"""
assert isinstance(
dst, Float8BlockwiseQTensor
), f"Cannot store quantized blockwise tensor in {type(dst)} type."
# Make sure input is in expected format
if not devices_match(src.device, dst.device):
src = src.to(device=dst.device)
if not src.is_contiguous():
src = src.contiguous()
# Launch cast kernel
tex.quantize(src, self, dst, noop_flag)
dst._fp8_dtype = self.dtype
return dst
def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, int]:
"""Calculate the shape of the scaling tensor for blockwise quantization.
This method determines the shape of the scaling tensor needed for blockwise quantization,
taking into account the input tensor shape and whether columnwise scaling is used.
The scales are padded to multiples of 4 on the inner dimension for compatibility with GEMM.
Parameters
----------
shape : Iterable[int]
Shape of the input tensor to be quantized
columnwise : bool
Whether to use columnwise scaling (True) or rowwise scaling (False)
Returns
-------
Tuple[int, int]
Shape of the scaling tensor as (outer_dim, inner_dim)
For 2D tensors:
- If columnwise: (roundup(K/blocksize), round_to_multiple(roundup(M/blocksize), 4))
- If rowwise: (roundup(M/blocksize), round_to_multiple(roundup(K/blocksize), 4))
For 1D tensors:
- If columnwise: (roundup(M/blocksize), round_to_multiple(K, 4))
- If rowwise: (roundup(K/blocksize), round_to_multiple(M, 4))
"""
M, K = 1, 1
for i in range(len(shape) - 1):
M *= shape[i]
if len(shape) > 0:
K = shape[-1]
if self.block_scaling_dim == 2:
if columnwise:
outer = math.ceil(K / self.block_len)
inner = round_up_to_nearest_multiple(math.ceil(M / self.block_len), 4)
return (outer, inner)
outer = math.ceil(M / self.block_len)
inner = round_up_to_nearest_multiple(math.ceil(K / self.block_len), 4)
return (outer, inner)
assert self.block_scaling_dim == 1, "Only 1D or 2D blocks supported"
if columnwise:
outer = math.ceil(M / self.block_len)
inner = round_up_to_nearest_multiple(K, 4)
return (outer, inner)
outer = math.ceil(K / self.block_len)
inner = round_up_to_nearest_multiple(M, 4)
return (outer, inner)
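    # Worked example of the rules above (values follow directly from this method):
    # for shape == (256, 1000) and block_len == 128:
    #   block_scaling_dim == 2: rowwise    -> (ceil(256/128), round_up(ceil(1000/128), 4)) == (2, 8)
    #                           columnwise -> (ceil(1000/128), round_up(ceil(256/128), 4)) == (8, 4)
    #   block_scaling_dim == 1: rowwise    -> (ceil(1000/128), round_up(256, 4)) == (8, 256)
    #                           columnwise -> (ceil(256/128), round_up(1000, 4)) == (2, 1000)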
def get_columnwise_shape(self, shape: Iterable[int]) -> Tuple[int, ...]:
"""Calculate the shape of a tensor after columnwise permutation.
This method rearranges the dimensions of a tensor to be columnwise,
moving the last dimension to the front and keeping the order of other dimensions.
Parameters
----------
shape : Iterable[int]
Original shape of the tensor
Returns
-------
Tuple[int, ...]
New shape with dimensions rearranged for columnwise layout.
For a shape (d1, d2, ..., dn), returns (dn, d1, d2, ..., dn-1).
Returns empty tuple for empty input shape.
"""
if len(shape) == 0:
return tuple()
colwise_shape = [shape[-1]]
for i in range(len(shape) - 1):
colwise_shape.append(shape[i])
return tuple(colwise_shape)
# TODO(kwyss): With FP8 gather support, we need to implement a
# shape/layout/swizzle check to know whether FP8 gather works
# cleanly by stacking data without aliasing tiles and whether
# the scales also stack on the proper dimensions.
def make_empty(
self,
shape: Iterable[int],
*,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
requires_grad: bool = False,
) -> Float8BlockwiseQTensor:
"""Construct quantized tensor with uninitialized data"""
if device is None:
device = torch.device("cuda")
# Allocate FP8 data
data = None
scale_inv = None
if self.rowwise_usage:
data = torch.empty(shape, dtype=torch.uint8, device=device)
scale_shape = self.get_scale_shape(shape, columnwise=False)
scale_inv = torch.empty(
scale_shape,
dtype=torch.float32,
device=device,
)
# Allocate FP8 data transpose if needed
columnwise_data = None
columnwise_scale_inv = None
if self.columnwise_usage:
columnwise_data = torch.empty(
self.get_columnwise_shape(shape), dtype=torch.uint8, device=device
)
columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True)
columnwise_scale_inv = torch.empty(
columnwise_scale_shape,
dtype=torch.float32,
device=device,
)
# Construct FP8 tensor
return Float8BlockwiseQTensor(
shape=shape,
dtype=dtype,
fp8_dtype=self.dtype,
rowwise_data=data,
rowwise_scale_inv=scale_inv,
columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv,
quantizer=self,
is_2D_scaled=self.block_scaling_dim == 2,
requires_grad=requires_grad,
)
def calibrate(self, tensor: torch.Tensor) -> None:
# NOTE: This interface is specific to requirements like delayed scaling
# where state from an estimator influences distribution parameters.
pass
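# A hedged usage sketch (not part of this file): quantize a tensor with the
# blockwise quantizer above and read it back. Only methods defined here are
# used; the FP8 dtype enum spelling (tex.DType.kFloat8E4M3) is an assumption.
#
#   quantizer = Float8BlockQuantizer(
#       fp8_dtype=tex.DType.kFloat8E4M3,
#       rowwise=True,
#       columnwise=False,
#       block_scaling_dim=2,
#   )
#   x = torch.randn(256, 1024, device="cuda", dtype=torch.bfloat16)
#   x_fp8 = quantizer.make_empty(x.shape, dtype=x.dtype, device=x.device)
#   quantizer.update_quantized(x, x_fp8)
#   x_hp = x_fp8.dequantize(dtype=torch.bfloat16)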
class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
"""Tensor class with FP8 data quantized via NxN blocks or 1xN blocks.
The tensor presents as having a standard, higher-precision dtype,
but the data itself is (scaled) FP8. For most tensor operations,
the data will be cast to the nominal dtype before performing the
operation.
Parameters
----------
rowwise_data: torch.Tensor
FP8 data in a uint8 tensor matching shape of dequantized tensor.
rowwise_scale_inv: torch.Tensor
FP32 dequantization scales in GEMM format for dequantizing rowwise_data.
columnwise_data: Optional[torch.Tensor]
FP8 data in a uint8 tensor matching shape of dequantized tensor transpose.
columnwise_scale_inv: Optional[torch.Tensor]
FP32 dequantization scales in GEMM format for dequantizing columnwise_data.
fp8_dtype: transformer_engine_torch.DType, default = kFloat8E4M3
FP8 format.
quantizer: Quantizer - the Float8BlockQuantizer that quantized this tensor and
holds configuration about quantization and dequantization modes.
"""
def __repr__(self, *, tensor_contents=None):
return (
f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype},"
f" is_2D_scaled={self._is_2D_scaled},"
f" data={self.dequantize(dtype=self.dtype)})"
)
def _get_quantizer(self) -> Quantizer:
"""Get builder for quantized tensor
Quantizer can be used for in-place operations.
"""
assert self._quantizer is not None
return self._quantizer
def quantize_(
self,
tensor: torch.Tensor,
*,
noop_flag: Optional[torch.Tensor] = None,
) -> Float8BlockwiseQTensor:
"""Update FP8 data
Parameters
----------
tensor: torch.Tensor
Tensor to copy from
noop_flag: torch.Tensor, optional
float32 flag indicating whether to avoid performing update
"""
if isinstance(tensor, QuantizedTensor):
return self.quantize_(tensor.dequantize())
self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag)
return self
def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
"""
Construct plain PyTorch tensor from Float8BlockwiseQTensor
By default the resulting tensor's dtype is the
Float8BlockwiseQTensor's pre-quantized dtype.
"""
if dtype is not None:
dequant_dtype = dtype
else:
dequant_dtype = self.dtype
return super().dequantize(dtype=dequant_dtype)
def detach(self) -> Float8BlockwiseQTensor:
# pylint: disable=missing-function-docstring
return Float8BlockwiseQTensor.make_like(self)
def update_usage(
self, rowwise_usage: Optional[bool] = None, columnwise_usage: Optional[bool] = None
):
"""
update_usage can be used to clear out one of two possible copies of the data.
"""
if rowwise_usage is None:
rowwise_usage = self._rowwise_data is not None
if columnwise_usage is None:
columnwise_usage = self._columnwise_data is not None
assert (
columnwise_usage or rowwise_usage
), "Must retain some data either columnwise or rowwise"
if columnwise_usage and rowwise_usage:
assert (
self._rowwise_data is not None
and self._rowwise_scale_inv is not None
and self._columnwise_data is not None
and self._columnwise_scale_inv is not None
), "Cannot update to rowwise and columnwise usage."
return
if rowwise_usage:
assert (
self._rowwise_data is not None and self._rowwise_scale_inv is not None
), "Cannot update to rowwise usage."
self._columnwise_data = None
self._columnwise_scale_inv = None
return
if columnwise_usage:
assert (
self._columnwise_data is not None and self._columnwise_scale_inv is not None
), "Cannot update to columnwise usage."
self._rowwise_data = None
self._rowwise_scale_inv = None
return
return
def clone(self) -> Float8BlockwiseQTensor:
# pylint: disable=missing-function-docstring
rowwise_data = None
if self._rowwise_data is not None:
rowwise_data = self._rowwise_data.detach().clone()
columnwise_data = None
if self._columnwise_data is not None:
columnwise_data = self._columnwise_data.detach().clone()
return _IdentityFunc.apply(
self,
{
"rowwise_data": rowwise_data,
"columnwise_data": columnwise_data,
},
)
def view(self, *shape: Tuple[int]) -> Float8BlockwiseQTensor:
# pylint: disable=missing-function-docstring
return _ViewFunc.apply(self, shape)
def reshape(self, *shape: Tuple[int]) -> Float8BlockwiseQTensor:
# pylint: disable=missing-function-docstring
return _ReshapeFunc.apply(self, shape)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
# View op
if func == aten.view.default:
tensor = args[0]
data = tensor._rowwise_data
if data is None:
# Columnwise data only.
super().__torch_dispatch__(func, types, args, kwargs)
orig_size = data.size()
out_data = data.__torch_dispatch__(
func,
types,
[data] + list(args[1:]),
kwargs,
)
if orig_size != out_data.size():
raise NotImplementedError(
"Changing shape with view not implemented "
" (scales and columnwise data untouched)."
)
return Float8BlockwiseQTensor.make_like(tensor)
# Default case
return super().__torch_dispatch__(func, types, args, kwargs)
def contiguous(
self,
memory_format: torch.memory_format = torch.contiguous_format,
) -> Float8BlockwiseQTensor:
"""Returns tensor with data in provided memory format
Returns `self` if data is already in correct memory format.
"""
if (
self._rowwise_data is not None
and self._rowwise_data.is_contiguous(memory_format=memory_format)
and (
(self._columnwise_data is None)
or (self._columnwise_data.is_contiguous(memory_format=memory_format))
)
):
return self
raise ValueError("Float8BlockwiseQTensor does not support different memory formats!")
def clear(self):
"""Deallocate this tensor's memory. Typically not needed and must be used carefully."""
self._rowwise_data = torch.Tensor() if self._rowwise_data is not None else None
self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None
@classmethod
def _make_in_reduce_ex(
cls,
shape: torch.Size,
rowwise_data: torch.Tensor,
rowwise_scale_inv: torch.Tensor,
columnwise_data: torch.Tensor,
columnwise_scale_inv: torch.Tensor,
fp8_dtype: TE_DType,
dtype: torch.dtype,
quantizer: Quantizer,
is_2D_scaled: bool,
) -> Float8BlockwiseQTensor:
"""Build Float8BlockwiseQTensor, for use in __reduce__
__reduce_ex__ assumes object constructor has positional
arguments.
"""
return Float8BlockwiseQTensor(
shape=shape,
rowwise_data=rowwise_data,
rowwise_scale_inv=rowwise_scale_inv,
fp8_dtype=fp8_dtype,
columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv,
dtype=dtype,
quantizer=quantizer,
is_2D_scaled=is_2D_scaled,
)
def __reduce_ex__(self, protocol: int) -> tuple:
"""Custom pickling to remove references to FP8 metadata objects"""
return (
Float8BlockwiseQTensor._make_in_reduce_ex,
(
self.shape,
self._rowwise_data,
self._rowwise_scale_inv,
self._columnwise_data,
self._columnwise_scale_inv,
self._fp8_dtype,
self.dtype,
self._quantizer,
self._is_2D_scaled,
),
)
def _get_data(self) -> Float8BlockwiseQTensor:
"""Get tensor data property"""
return self
@torch.no_grad()
def _set_data(self, tensor: torch.Tensor) -> None:
"""Set tensor data property
Just takes FP8 data if setting from a Float8BlockwiseQTensor. Otherwise
casts to FP8.
"""
# Tensor device
new_device = tensor.device if tensor.is_cuda else self.device
def _set_from_tensor(dst: Float8BlockwiseQTensor, src: Float8BlockwiseQTensor):
dst._rowwise_data = src._rowwise_data
dst._columnwise_data = src._columnwise_data
dst._quantizer = src._quantizer
dst._fp8_dtype = src._fp8_dtype
dst._rowwise_scale_inv = src._rowwise_scale_inv
dst._columnwise_scale_inv = src._columnwise_scale_inv
# Check that tensor dimensions match
if (
self.size() != tensor.size()
or self.stride() != tensor.stride()
or self.layout != tensor.layout
):
raise ValueError("Invalid tensor for updating Float8BlockwiseQTensor data")
# Just copy FP8 data if other tensor is Float8BlockwiseQTensor
if (
isinstance(tensor, Float8BlockwiseQTensor)
and self.storage_offset() == tensor.storage_offset()
and devices_match(self.device, new_device)
):
_set_from_tensor(self, tensor)
return
if isinstance(tensor, Float8BlockwiseQTensor):
assert tensor._quantizer is not None, "Can't quantize without a quantizer"
quantizer = tensor._quantizer
else:
assert self._quantizer is not None, "Can't quantize without a quantizer"
quantizer = self._quantizer
# Quantize to FP8
quantizer.update_quantized(tensor, self)
# Cast to FP8 when setting Float8BlockwiseQTensor.data
data = property(_get_data, _set_data)
class _ViewFunc(torch.autograd.Function):
"""View function
View the Float8BlockwiseQTensor using the provided shape.
"""
@staticmethod
def forward(
ctx,
tensor: Float8BlockwiseQTensor,
shape: Optional[list[int]] = None,
) -> Float8BlockwiseQTensor:
# pylint: disable=missing-function-docstring
# Return input tensor if shape is not provided
if ctx is not None:
ctx.shape = tensor.shape
if shape is None:
return tensor
if list(shape) != list(tensor.shape):
raise NotImplementedError("View not implemented.")
return tensor
@staticmethod
def backward(
ctx,
grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
# pylint: disable=missing-function-docstring
if isinstance(grad, Float8BlockwiseQTensor):
raise NotImplementedError("View bwd not implemented")
return grad.view(ctx.shape), None
class _ReshapeFunc(torch.autograd.Function):
"""Reshape function
Reshape the Float8BlockwiseQTensor using the provided shape.
"""
@staticmethod
def forward(
ctx,
tensor: Float8BlockwiseQTensor,
shape: Optional[list[int]] = None,
) -> Float8BlockwiseQTensor:
# pylint: disable=missing-function-docstring
# Return input tensor if shape is not provided
if ctx is not None:
ctx.shape = tensor.shape
if shape is None:
return tensor
# Canonicalize shape
if not isinstance(shape, Iterable):
shape = [shape]
elif len(shape) == 1 and isinstance(shape[0], Iterable):
shape = shape[0]
if -1 in shape:
shape = list(shape)
d_inferred = -math.prod(tensor.shape) // math.prod(shape)
for i, d in enumerate(shape):
if d == -1:
shape[i] = d_inferred
break
if list(shape) != list(tensor.shape):
raise NotImplementedError("Reshape not implemented yet.")
return tensor
@staticmethod
def backward(
ctx,
grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
# pylint: disable=missing-function-docstring
if isinstance(grad, Float8BlockwiseQTensor):
raise NotImplementedError("Reshape bwd not implemented yet.")
return grad.view(ctx.shape), None
@@ -422,13 +422,6 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
         # pylint: disable=missing-function-docstring
         return Float8Tensor.make_like(self)

-    def _create_transpose(self):
-        data = self._data
-        if not data.is_contiguous():
-            data = data.contiguous()
-        self._transpose = tex.fp8_transpose(data, self._fp8_dtype, out=self._transpose)
-        self._transpose_invalid = False
-
     def update_usage(
         self,
         rowwise_usage: Optional[bool] = None,
...
...@@ -347,6 +347,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor): ...@@ -347,6 +347,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
columnwise_scale_inv: torch.Tensor, columnwise_scale_inv: torch.Tensor,
fp8_dtype: TE_DType, fp8_dtype: TE_DType,
dtype: torch.dtype, dtype: torch.dtype,
shape: torch.shape,
) -> MXFP8Tensor: ) -> MXFP8Tensor:
"""Build MXFP8Tensor, for use in __reduce__ """Build MXFP8Tensor, for use in __reduce__
...@@ -361,10 +362,11 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor): ...@@ -361,10 +362,11 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
columnwise_data=columnwise_data, columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv, columnwise_scale_inv=columnwise_scale_inv,
dtype=dtype, dtype=dtype,
shape=shape,
) )
def __reduce_ex__(self, protocol: int) -> tuple: def __reduce_ex__(self, protocol: int) -> tuple:
"""Custom pickling to remove references to FP8 metadata objects""" """Custom pickling"""
return ( return (
MXFP8Tensor._make_in_reduce_ex, MXFP8Tensor._make_in_reduce_ex,
( (
...@@ -374,6 +376,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor): ...@@ -374,6 +376,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
self._columnwise_scale_inv, self._columnwise_scale_inv,
self._fp8_dtype, self._fp8_dtype,
self.dtype, self.dtype,
self.shape,
), ),
) )
......
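Capturing self.shape in __reduce_ex__ matters because the tensor is rebuilt purely from its raw row-/column-wise FP8 payload, and the logical shape is not recoverable from that payload alone. A hedged round-trip sketch, assuming t is an existing MXFP8Tensor:

import pickle

# Assumes `t` is an MXFP8Tensor created elsewhere; with the shape now pickled,
# the restored tensor reports the same logical shape as the original.
t_restored = pickle.loads(pickle.dumps(t))
assert t_restored.shape == t.shape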
...@@ -37,7 +37,8 @@ def prepare_for_saving( ...@@ -37,7 +37,8 @@ def prepare_for_saving(
def restore_from_saved( def restore_from_saved(
tensors: list[Optional[Any]], tensors: list[Optional[Any]],
saved_tensors: list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], saved_tensors: list[Optional[Union[torch.Tensor, torch.nn.Parameter]]],
) -> list[Optional[Any]]: return_saved_tensors: bool = False,
) -> list[Optional[Any]] | tuple[list[Optional[Any]], list[Optional[torch.Tensor]]]:
"""Recombine the tensor data and metadata during backward pass.""" """Recombine the tensor data and metadata during backward pass."""
tensor_objects = [] tensor_objects = []
for tensor in tensors: for tensor in tensors:
...@@ -47,6 +48,9 @@ def restore_from_saved( ...@@ -47,6 +48,9 @@ def restore_from_saved(
else: else:
saved_tensors = tensor.restore_from_saved(saved_tensors) saved_tensors = tensor.restore_from_saved(saved_tensors)
tensor_objects.append(tensor) tensor_objects.append(tensor)
if return_saved_tensors:
return tensor_objects, saved_tensors
return tensor_objects return tensor_objects
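A hedged usage sketch of the new return_saved_tensors flag, assuming prepare_for_saving(...) produced the flat tensor list saved on the forward side and the matching metadata stubs (names here are illustrative):

# Forward: split tensor objects into plain tensors plus lightweight stubs.
tensors_to_save, tensor_stubs = prepare_for_saving(weight, act)  # assumed return order
ctx.save_for_backward(*tensors_to_save)
ctx.tensor_stubs = tensor_stubs

# Backward: recombine; with return_saved_tensors=True the remaining, unconsumed
# saved tensors are handed back instead of being dropped.
objs, leftover = restore_from_saved(
    ctx.tensor_stubs, list(ctx.saved_tensors), return_saved_tensors=True
)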
...@@ -113,7 +117,11 @@ class Quantizer(abc.ABC): ...@@ -113,7 +117,11 @@ class Quantizer(abc.ABC):
"""Quantize tensor in-place""" """Quantize tensor in-place"""
def quantize( def quantize(
self, tensor: torch.Tensor, *, out: Optional[QuantizedTensor] = None self,
tensor: torch.Tensor,
*,
out: Optional[QuantizedTensor] = None,
dtype: Optional[torch.dtype] = None, # pylint: disable=unused-argument # used by override
) -> QuantizedTensor: ) -> QuantizedTensor:
"""Quantize tensor""" """Quantize tensor"""
if out is not None: if out is not None:
......
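A minimal call sketch for the new dtype keyword; the base-class quantize() deliberately ignores it, while subclasses that override quantize() may use it to choose the nominal (dequantized) dtype of the result. Example values are illustrative:

x = torch.randn(128, 256, device="cuda")
# `quantizer` is any Quantizer subclass; the base implementation ignores `dtype`.
x_q = quantizer.quantize(x, dtype=torch.bfloat16)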
...@@ -39,7 +39,9 @@ def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor): ...@@ -39,7 +39,9 @@ def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor):
raise ValueError(f"replace_raw_data for {type(tensor)} is not supported yet") raise ValueError(f"replace_raw_data for {type(tensor)} is not supported yet")
def cast_master_weights_to_fp8(model_weights, master_weights, start_offsets, group): def cast_master_weights_to_fp8(
model_weights, master_weights, start_offsets, group, fsdp_shard_model_weights=None
):
r"""Helper function to cast master weights to FP8 primary weights. r"""Helper function to cast master weights to FP8 primary weights.
This is intended for use with ZeRO/FSDP. Each rank has a shard of This is intended for use with ZeRO/FSDP. Each rank has a shard of
...@@ -56,14 +58,23 @@ def cast_master_weights_to_fp8(model_weights, master_weights, start_offsets, gro ...@@ -56,14 +58,23 @@ def cast_master_weights_to_fp8(model_weights, master_weights, start_offsets, gro
should be updated. should be updated.
group : The distributed group to do amax reduction. Typically it's the data parallel group : The distributed group to do amax reduction. Typically it's the data parallel
group. group.
fsdp_shard_model_weights : list of FSDP-sharded model weights, or None. If None, the model
weights are not sharded. Otherwise each entry supplies the data storage of the
corresponding sharded model weight that the quantized result is written into.
""" """
delayed_scaling_params = [] delayed_scaling_params = []
current_scaling_params = [] current_scaling_params = []
for model_weight, master_weight, start_offset in zip( if fsdp_shard_model_weights is None:
model_weights, master_weights, start_offsets use_fsdp_shard_model_weights = False
fsdp_shard_model_weights = [None] * len(model_weights)
else:
use_fsdp_shard_model_weights = True
for model_weight, master_weight, start_offset, fsdp_shard_model_weight in zip(
model_weights, master_weights, start_offsets, fsdp_shard_model_weights
): ):
# Clear `_high_precision_init_val` of model_weight automatically. # Clear `_high_precision_init_val` of model_weight automatically.
# - Master weights are initialized from model weights, if we use fp8 primary weights to # - Master weights are initialized from model weights, if we use fp8 primary weights to
...@@ -89,9 +100,13 @@ def cast_master_weights_to_fp8(model_weights, master_weights, start_offsets, gro ...@@ -89,9 +100,13 @@ def cast_master_weights_to_fp8(model_weights, master_weights, start_offsets, gro
quantizer = model_weight._get_quantizer() quantizer = model_weight._get_quantizer()
if isinstance(quantizer, Float8Quantizer): if isinstance(quantizer, Float8Quantizer):
delayed_scaling_params.append((model_weight, master_weight, start_offset)) delayed_scaling_params.append(
(model_weight, master_weight, start_offset, fsdp_shard_model_weight)
)
elif isinstance(quantizer, Float8CurrentScalingQuantizer): elif isinstance(quantizer, Float8CurrentScalingQuantizer):
current_scaling_params.append((model_weight, master_weight, start_offset)) current_scaling_params.append(
(model_weight, master_weight, start_offset, fsdp_shard_model_weight)
)
elif isinstance(quantizer, MXFP8Quantizer): elif isinstance(quantizer, MXFP8Quantizer):
raise NotImplementedError( raise NotImplementedError(
"cast_master_weights_to_fp8 for MXFP8BlockScaling is not supported yet" "cast_master_weights_to_fp8 for MXFP8BlockScaling is not supported yet"
...@@ -102,12 +117,16 @@ def cast_master_weights_to_fp8(model_weights, master_weights, start_offsets, gro ...@@ -102,12 +117,16 @@ def cast_master_weights_to_fp8(model_weights, master_weights, start_offsets, gro
) )
if len(delayed_scaling_params) > 0: if len(delayed_scaling_params) > 0:
_cast_master_weights_to_fp8_delayed_scaling(delayed_scaling_params, group) _cast_master_weights_to_fp8_delayed_scaling(
delayed_scaling_params, group, use_fsdp_shard_model_weights
)
if len(current_scaling_params) > 0: if len(current_scaling_params) > 0:
_cast_master_weights_to_fp8_current_scaling(current_scaling_params, group) _cast_master_weights_to_fp8_current_scaling(
current_scaling_params, group, use_fsdp_shard_model_weights
)
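A hedged call sketch of the extended entry point; fsdp_shard_model_weights, when given, must run parallel to model_weights and supplies the storage each quantized shard is written into (variable names are illustrative):

# Previous behaviour: model weights are not sharded.
cast_master_weights_to_fp8(model_weights, master_weights, start_offsets, dp_group)

# FSDP case: one shard tensor per model weight, same ordering.
cast_master_weights_to_fp8(
    model_weights,
    master_weights,
    start_offsets,
    dp_group,
    fsdp_shard_model_weights=fsdp_shards,
)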
def _cast_master_weights_to_fp8_delayed_scaling(params, group): def _cast_master_weights_to_fp8_delayed_scaling(params, group, use_fsdp_shard_model_weights=False):
r"""Helper function to cast master weights to FP8 primary weights for delayed scaling. r"""Helper function to cast master weights to FP8 primary weights for delayed scaling.
Parameters Parameters
...@@ -116,13 +135,14 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group): ...@@ -116,13 +135,14 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group):
indicating the starting index of the master weight in the model weight. indicating the starting index of the master weight in the model weight.
group : The distributed group to do amax reduction. Typically it's the data parallel group : The distributed group to do amax reduction. Typically it's the data parallel
group. group.
use_fsdp_shard_model_weights : bool. If True, the model weights are FSDP-sharded and the shard tensors passed in `params` provide the raw data storage.
""" """
# Collect amaxes to do reduce-max among dp group. # Collect amaxes to do reduce-max among dp group.
# Collect scales and scale_invs to update scale_invs of the fp8 weights. # Collect scales and scale_invs to update scale_invs of the fp8 weights.
amaxes, scales, scale_invs = [], [], [] amaxes, scales, scale_invs = [], [], []
for model_weight, master_weight, start_offset in params: for model_weight, master_weight, start_offset, shard_model_weight_raw in params:
# Reset transpose cache for all model weights. # Reset transpose cache for all model weights.
# We cannot create transpose cache here because users (like megatron) may want to overlap # We cannot create transpose cache here because users (like megatron) may want to overlap
# the all-gather of model weights and forward process, so the model weight is not updated # the all-gather of model weights and forward process, so the model weight is not updated
...@@ -148,6 +168,7 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group): ...@@ -148,6 +168,7 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group):
# master_weight may be smaller than model_weight because it could be distributed across # master_weight may be smaller than model_weight because it could be distributed across
# multiple ranks. So we need to create a dummy weight using the raw data from model_weight. # multiple ranks. So we need to create a dummy weight using the raw data from model_weight.
if not use_fsdp_shard_model_weights:
shard_model_weight_raw = model_weight._data.view(-1)[start_offset:end_offset] shard_model_weight_raw = model_weight._data.view(-1)[start_offset:end_offset]
shard_model_weight_fp8 = quantizer.create_tensor_from_data( shard_model_weight_fp8 = quantizer.create_tensor_from_data(
shard_model_weight_raw.view(1, -1), shard_model_weight_raw.view(1, -1),
...@@ -187,7 +208,7 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group): ...@@ -187,7 +208,7 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group):
) )
def _cast_master_weights_to_fp8_current_scaling(params, group): def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_model_weights=False):
r"""Helper function to cast master weights to FP8 primary weights for current scaling. r"""Helper function to cast master weights to FP8 primary weights for current scaling.
Parameters Parameters
...@@ -196,6 +217,7 @@ def _cast_master_weights_to_fp8_current_scaling(params, group): ...@@ -196,6 +217,7 @@ def _cast_master_weights_to_fp8_current_scaling(params, group):
indicating the starting index of the master weight in the model weight. indicating the starting index of the master weight in the model weight.
group : The distributed group to do amax reduction. Typically it's the data parallel group : The distributed group to do amax reduction. Typically it's the data parallel
group. group.
use_fsdp_shard_model_weights : bool. If True, the model weights are FSDP-sharded and the shard tensors passed in `params` provide the target data storage.
""" """
# Parameter attributes # Parameter attributes
...@@ -220,7 +242,7 @@ def _cast_master_weights_to_fp8_current_scaling(params, group): ...@@ -220,7 +242,7 @@ def _cast_master_weights_to_fp8_current_scaling(params, group):
# amaxes in a contiguous buffer. If the master weight is None, the corresponding amax # amaxes in a contiguous buffer. If the master weight is None, the corresponding amax
# will be set to 0. # will be set to 0.
# --------------------------------------------------------------------------------------------- # ---------------------------------------------------------------------------------------------
for (model_weight, master_weight, _), amax in zip(params, amaxes): for (model_weight, master_weight, _, _), amax in zip(params, amaxes):
# Make sure all the model weights have the same numerical options. # Make sure all the model weights have the same numerical options.
quantizer = model_weight._get_quantizer() quantizer = model_weight._get_quantizer()
...@@ -261,7 +283,9 @@ def _cast_master_weights_to_fp8_current_scaling(params, group): ...@@ -261,7 +283,9 @@ def _cast_master_weights_to_fp8_current_scaling(params, group):
# --------------------------------------------------------------------------------------------- # ---------------------------------------------------------------------------------------------
# Step 4: Cast master weights to FP8. # Step 4: Cast master weights to FP8.
# --------------------------------------------------------------------------------------------- # ---------------------------------------------------------------------------------------------
for (model_weight, master_weight, start_offset), scale in zip(params, scales): for (model_weight, master_weight, start_offset, model_weight_fragment), scale in zip(
params, scales
):
# Reset transpose cache for all model weights. # Reset transpose cache for all model weights.
# We cannot create transpose cache here because users (like megatron) may want to overlap # We cannot create transpose cache here because users (like megatron) may want to overlap
# the all-gather of model weights and forward process, so the model weight is not updated # the all-gather of model weights and forward process, so the model weight is not updated
...@@ -275,10 +299,18 @@ def _cast_master_weights_to_fp8_current_scaling(params, group): ...@@ -275,10 +299,18 @@ def _cast_master_weights_to_fp8_current_scaling(params, group):
# Cast master weight to FP8 # Cast master weight to FP8
end_offset = start_offset + master_weight.numel() end_offset = start_offset + master_weight.numel()
if not use_fsdp_shard_model_weights:
model_weight_fragment = model_weight.reshape(-1)[start_offset:end_offset] model_weight_fragment = model_weight.reshape(-1)[start_offset:end_offset]
quantizer = Float8Quantizer( quantizer = Float8Quantizer(
scale=scale, scale=scale,
amax=torch.Tensor(), amax=torch.Tensor(),
fp8_dtype=model_weight._fp8_dtype, fp8_dtype=model_weight._fp8_dtype,
) )
if use_fsdp_shard_model_weights and not isinstance(model_weight_fragment, Float8Tensor):
# NOTE: The FSDP shard model weight may be a uint8 tensor instead of
# a Float8Tensor, so wrap its raw data in a Float8Tensor before quantizing into it.
model_weight_fragment = quantizer.create_tensor_from_data(
model_weight_fragment.view(-1),
model_weight.dtype,
)
quantizer.update_quantized(master_weight, model_weight_fragment) quantizer.update_quantized(master_weight, model_weight_fragment)
...@@ -11,6 +11,7 @@ from typing import Callable, List, Optional, Tuple, Union ...@@ -11,6 +11,7 @@ from typing import Callable, List, Optional, Tuple, Union
import torch import torch
from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.attention import ( from transformer_engine.pytorch.attention import (
MultiheadAttention, MultiheadAttention,
) )
...@@ -33,6 +34,7 @@ from transformer_engine.pytorch.constants import ( ...@@ -33,6 +34,7 @@ from transformer_engine.pytorch.constants import (
dist_group_type, dist_group_type,
) )
from transformer_engine.pytorch.distributed import get_distributed_world_size from transformer_engine.pytorch.distributed import get_distributed_world_size
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
warnings.filterwarnings("module", category=DeprecationWarning, module="transformer") warnings.filterwarnings("module", category=DeprecationWarning, module="transformer")
...@@ -184,6 +186,8 @@ class TransformerLayer(torch.nn.Module): ...@@ -184,6 +186,8 @@ class TransformerLayer(torch.nn.Module):
head size. Note that these formats are very closely head size. Note that these formats are very closely
related to the `qkv_format` in the `MultiHeadAttention` related to the `qkv_format` in the `MultiHeadAttention`
and `DotProductAttention` modules. and `DotProductAttention` modules.
name: str, default = `None`
name of the module, currently used for debugging purposes.
Parallelism parameters Parallelism parameters
---------------------- ----------------------
...@@ -277,6 +281,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -277,6 +281,7 @@ class TransformerLayer(torch.nn.Module):
normalization: str = "LayerNorm", normalization: str = "LayerNorm",
device: Union[torch.device, str] = "cuda", device: Union[torch.device, str] = "cuda",
attn_input_format: str = "sbhd", attn_input_format: str = "sbhd",
name: str = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -336,6 +341,8 @@ class TransformerLayer(torch.nn.Module): ...@@ -336,6 +341,8 @@ class TransformerLayer(torch.nn.Module):
self.attn_input_format = attn_input_format self.attn_input_format = attn_input_format
self.name = name
attention_args = ( attention_args = (
hidden_size, hidden_size,
num_attention_heads, num_attention_heads,
...@@ -376,6 +383,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -376,6 +383,7 @@ class TransformerLayer(torch.nn.Module):
return_bias=not self.parallel_attention_mlp, return_bias=not self.parallel_attention_mlp,
normalization=normalization, normalization=normalization,
device=device, device=device,
name=name + ".self_attention" if name is not None else None,
) )
if layer_type == "decoder": if layer_type == "decoder":
...@@ -389,6 +397,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -389,6 +397,7 @@ class TransformerLayer(torch.nn.Module):
return_bias=True, return_bias=True,
normalization=normalization, normalization=normalization,
device=device, device=device,
name=name + ".inter_attention" if name is not None else None,
) )
# LayerNorm -> activation(Linear + Bias) -> Linear # LayerNorm -> activation(Linear + Bias) -> Linear
...@@ -423,6 +432,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -423,6 +432,7 @@ class TransformerLayer(torch.nn.Module):
activation=activation, activation=activation,
normalization=normalization, normalization=normalization,
device=device, device=device,
name=name + ".layernorm_mlp" if name is not None else None,
) )
self.hidden_dropout = hidden_dropout self.hidden_dropout = hidden_dropout
...@@ -679,6 +689,9 @@ class TransformerLayer(torch.nn.Module): ...@@ -679,6 +689,9 @@ class TransformerLayer(torch.nn.Module):
enc_dec_attn_mask[i].dtype == torch.bool for i in range(len(enc_dec_attn_mask)) enc_dec_attn_mask[i].dtype == torch.bool for i in range(len(enc_dec_attn_mask))
), "Encoder-decoder attention mask must be boolean tensor(s)" ), "Encoder-decoder attention mask must be boolean tensor(s)"
if TEDebugState.debug_enabled:
TransformerEngineBaseModule._validate_name(self)
# For AMP # For AMP
if torch.is_autocast_enabled(): if torch.is_autocast_enabled():
hidden_states = cast_if_needed(hidden_states, torch.get_autocast_gpu_dtype()) hidden_states = cast_if_needed(hidden_states, torch.get_autocast_gpu_dtype())
......
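A hedged construction sketch for the new name argument; when debug features are enabled, the name is validated and propagated to the submodules with dotted suffixes (argument values are illustrative):

layer = TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    name="decoder_layer_0",  # submodules become "decoder_layer_0.self_attention", etc.
)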
...@@ -10,15 +10,6 @@ import torch ...@@ -10,15 +10,6 @@ import torch
import triton import triton
import triton.language as tl import triton.language as tl
from transformer_engine_torch import DType as TE_DType
from torch.utils.cpp_extension import IS_HIP_EXTENSION
if IS_HIP_EXTENSION:
e5m2_data_type = tl.float8e5b16
e4m3_data_type = tl.float8e4b8
else:
e5m2_data_type = tl.float8e5
e4m3_data_type = tl.float8e4nv
@triton.jit @triton.jit
def _row_id_map_pass_1_kernel( def _row_id_map_pass_1_kernel(
...@@ -123,11 +114,14 @@ def _permute_kernel( ...@@ -123,11 +114,14 @@ def _permute_kernel(
output_ptr, output_ptr,
row_id_map_ptr, row_id_map_ptr,
probs_ptr, probs_ptr,
scale_ptr,
permuted_probs_ptr, permuted_probs_ptr,
permuted_scale_ptr,
# sizes # sizes
num_tokens, num_tokens,
num_experts, num_experts,
hidden_size, hidden_size,
scale_hidden_dim,
# strides # strides
stride_input_token, stride_input_token,
stride_input_hidden, stride_input_hidden,
...@@ -135,9 +129,14 @@ def _permute_kernel( ...@@ -135,9 +129,14 @@ def _permute_kernel(
stride_output_hidden, stride_output_hidden,
stride_probs_token, stride_probs_token,
stride_probs_expert, stride_probs_expert,
stride_scale_token,
stride_scale_hidden,
stride_permuted_probs_token, stride_permuted_probs_token,
stride_permuted_scale_token,
stride_permuted_scale_hidden,
# metas # metas
PERMUTE_PROBS: tl.constexpr, PERMUTE_PROBS: tl.constexpr,
PERMUTE_SCALE: tl.constexpr,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
pid = tl.program_id(0) pid = tl.program_id(0)
...@@ -147,11 +146,21 @@ def _permute_kernel( ...@@ -147,11 +146,21 @@ def _permute_kernel(
mask = cur_off < hidden_size mask = cur_off < hidden_size
input_off = pid * stride_input_token + cur_off * stride_input_hidden input_off = pid * stride_input_token + cur_off * stride_input_hidden
inp = tl.load(input_ptr + input_off, mask=mask) inp = tl.load(input_ptr + input_off, mask=mask)
if PERMUTE_SCALE:
mask_scale = cur_off < scale_hidden_dim
scale_off = pid * stride_scale_token + cur_off * stride_scale_hidden
scale = tl.load(scale_ptr + scale_off, mask=mask_scale)
for expert_idx in range(num_experts): for expert_idx in range(num_experts):
dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid) dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid)
if dst_row != -1: if dst_row != -1:
output_off = dst_row * stride_output_token + cur_off * stride_output_hidden output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
tl.store(output_ptr + output_off, inp, mask=mask) tl.store(output_ptr + output_off, inp, mask=mask)
if PERMUTE_SCALE:
permuted_scale_off = (
dst_row * stride_permuted_scale_token
+ cur_off * stride_permuted_scale_hidden
)
tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale)
if PERMUTE_PROBS: if PERMUTE_PROBS:
if cur_pos == 0: if cur_pos == 0:
prob_off = pid * stride_probs_token + expert_idx * stride_probs_expert prob_off = pid * stride_probs_token + expert_idx * stride_probs_expert
...@@ -180,10 +189,12 @@ def permute_with_mask_map( ...@@ -180,10 +189,12 @@ def permute_with_mask_map(
inp: torch.Tensor, inp: torch.Tensor,
row_id_map: torch.Tensor, row_id_map: torch.Tensor,
probs: torch.Tensor, probs: torch.Tensor,
scale: torch.Tensor,
num_tokens: int, num_tokens: int,
num_experts: int, num_experts: int,
num_out_tokens: int, num_out_tokens: int,
hidden_size: int, hidden_size: int,
scale_hidden_dim: int,
): ):
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda") output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda")
...@@ -191,26 +202,42 @@ def permute_with_mask_map( ...@@ -191,26 +202,42 @@ def permute_with_mask_map(
permuted_probs = torch.empty((num_out_tokens,), dtype=probs.dtype, device="cuda") permuted_probs = torch.empty((num_out_tokens,), dtype=probs.dtype, device="cuda")
else: else:
permuted_probs = None permuted_probs = None
if scale is not None:
permuted_scale = torch.empty(
(num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device="cuda"
)
else:
permuted_scale = None
grid = (num_tokens,) grid = (num_tokens,)
_permute_kernel[grid]( _permute_kernel[grid](
inp, inp,
output, output,
row_id_map, row_id_map,
probs, probs,
scale,
permuted_probs, permuted_probs,
permuted_scale,
num_tokens, num_tokens,
num_experts, num_experts,
hidden_size, hidden_size,
scale_hidden_dim,
inp.stride(0), inp.stride(0),
inp.stride(1), inp.stride(1),
output.stride(0), output.stride(0),
output.stride(1), output.stride(1),
probs.stride(0) if probs is not None else None, probs.stride(0) if probs is not None else None,
probs.stride(1) if probs is not None else None, probs.stride(1) if probs is not None else None,
scale.stride(0) if scale is not None else None,
scale.stride(1) if scale is not None else None,
permuted_probs.stride(0) if permuted_probs is not None else None, permuted_probs.stride(0) if permuted_probs is not None else None,
permuted_scale.stride(0) if permuted_scale is not None else None,
permuted_scale.stride(1) if permuted_scale is not None else None,
PERMUTE_PROBS=probs is not None, PERMUTE_PROBS=probs is not None,
PERMUTE_SCALE=scale is not None,
) )
return output, permuted_probs return output, permuted_scale, permuted_probs
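A hedged call sketch for the extended permute helper; scale may be None, in which case permuted_scale comes back as None and no scale permutation is performed (shapes are illustrative):

# inp:   [num_tokens, hidden_size] token activations (possibly raw FP8 bytes)
# scale: [num_tokens, scale_hidden_dim] per-token scaling factors, or None
output, permuted_scale, permuted_probs = permute_with_mask_map(
    inp, row_id_map, probs, scale,
    num_tokens, num_experts, num_out_tokens, hidden_size, scale_hidden_dim,
)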
@triton.jit @triton.jit
...@@ -239,18 +266,9 @@ def _unpermute_kernel( ...@@ -239,18 +266,9 @@ def _unpermute_kernel(
# metas # metas
WITH_MERGING_PROBS: tl.constexpr, WITH_MERGING_PROBS: tl.constexpr,
PERMUTE_PROBS: tl.constexpr, PERMUTE_PROBS: tl.constexpr,
FP8_DTYPE: tl.constexpr,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
if FP8_DTYPE == "e5m2":
data_type = tl.float8e5
pytorch_tensor_dtype = tl.uint8
elif FP8_DTYPE == "e4m3":
data_type = tl.float8e4nv
pytorch_tensor_dtype = tl.uint8
else:
data_type = input_ptr.dtype.element_ty data_type = input_ptr.dtype.element_ty
assert FP8_DTYPE is None
compute_type = tl.float32 compute_type = tl.float32
pid = tl.program_id(0) pid = tl.program_id(0)
...@@ -264,8 +282,6 @@ def _unpermute_kernel( ...@@ -264,8 +282,6 @@ def _unpermute_kernel(
if src_row != -1: if src_row != -1:
input_off = src_row * stride_input_token + current_offset * stride_input_hidden input_off = src_row * stride_input_token + current_offset * stride_input_hidden
inp = tl.load(input_ptr + input_off, mask=mask) inp = tl.load(input_ptr + input_off, mask=mask)
if FP8_DTYPE is not None:
inp = inp.to(data_type, bitcast=True)
inp = inp.to(compute_type) inp = inp.to(compute_type)
if WITH_MERGING_PROBS: if WITH_MERGING_PROBS:
merging_prob_off = ( merging_prob_off = (
...@@ -286,13 +302,6 @@ def _unpermute_kernel( ...@@ -286,13 +302,6 @@ def _unpermute_kernel(
tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob) tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob)
else: else:
tl.store(unpermuted_probs_ptr + unpermuted_prob_off, 0.0) tl.store(unpermuted_probs_ptr + unpermuted_prob_off, 0.0)
if FP8_DTYPE is not None:
if not WITH_MERGING_PROBS:
# Directly adding these value may cause overflow for fp8, we scale it here.
# The outside fp8_scale_inv is also scaled in the meantime.
accumulator /= num_experts
accumulator = accumulator.to(data_type).to(pytorch_tensor_dtype, bitcast=True)
else:
accumulator = accumulator.to(data_type) accumulator = accumulator.to(data_type)
output_off = pid * stride_output_token + current_offset * stride_output_hidden output_off = pid * stride_output_token + current_offset * stride_output_hidden
tl.store(output_ptr + output_off, accumulator, mask=mask) tl.store(output_ptr + output_off, accumulator, mask=mask)
...@@ -322,15 +331,8 @@ def unpermute_with_mask_map( ...@@ -322,15 +331,8 @@ def unpermute_with_mask_map(
num_tokens: int, num_tokens: int,
num_experts: int, num_experts: int,
hidden_size: int, hidden_size: int,
fp8_dtype: TE_DType,
): ):
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
if fp8_dtype == TE_DType.kFloat8E5M2:
fp8_dtype = "e5m2"
elif fp8_dtype == TE_DType.kFloat8E4M3:
fp8_dtype = "e4m3"
else:
fp8_dtype = None
output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda") output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
if permuted_probs is not None: if permuted_probs is not None:
unpermuted_probs = torch.empty( unpermuted_probs = torch.empty(
...@@ -360,7 +362,6 @@ def unpermute_with_mask_map( ...@@ -360,7 +362,6 @@ def unpermute_with_mask_map(
unpermuted_probs.stride(1) if unpermuted_probs is not None else None, unpermuted_probs.stride(1) if unpermuted_probs is not None else None,
WITH_MERGING_PROBS=merging_probs is not None, WITH_MERGING_PROBS=merging_probs is not None,
PERMUTE_PROBS=permuted_probs is not None, PERMUTE_PROBS=permuted_probs is not None,
FP8_DTYPE=fp8_dtype,
) )
return output, unpermuted_probs return output, unpermuted_probs
...@@ -390,18 +391,9 @@ def _unpermute_bwd_with_merging_probs_kernel( ...@@ -390,18 +391,9 @@ def _unpermute_bwd_with_merging_probs_kernel(
stride_merging_probs_grad_token, stride_merging_probs_grad_token,
stride_merging_probs_grad_expert, stride_merging_probs_grad_expert,
# metas # metas
FP8_DTYPE: tl.constexpr,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
if FP8_DTYPE == "e5m2":
data_type = tl.float8e5
pytorch_tensor_dtype = tl.uint8
elif FP8_DTYPE == "e4m3":
data_type = tl.float8e4nv
pytorch_tensor_dtype = tl.uint8
else:
data_type = fwd_output_grad_ptr.dtype.element_ty data_type = fwd_output_grad_ptr.dtype.element_ty
assert FP8_DTYPE is None
compute_type = tl.float32 compute_type = tl.float32
pid = tl.program_id(0) pid = tl.program_id(0)
...@@ -418,8 +410,6 @@ def _unpermute_bwd_with_merging_probs_kernel( ...@@ -418,8 +410,6 @@ def _unpermute_bwd_with_merging_probs_kernel(
+ current_offset * stride_fwd_output_grad_hidden + current_offset * stride_fwd_output_grad_hidden
) )
inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask) inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask)
if FP8_DTYPE is not None:
inp = inp.to(data_type, bitcast=True)
inp = inp.to(compute_type) inp = inp.to(compute_type)
merging_prob_off = ( merging_prob_off = (
pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
...@@ -427,8 +417,6 @@ def _unpermute_bwd_with_merging_probs_kernel( ...@@ -427,8 +417,6 @@ def _unpermute_bwd_with_merging_probs_kernel(
merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type) merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
output = inp * merging_prob output = inp * merging_prob
output = output.to(data_type) output = output.to(data_type)
if FP8_DTYPE is not None:
output = output.to(pytorch_tensor_dtype, bitcast=True)
output_off = ( output_off = (
dst_row * stride_fwd_input_grad_token dst_row * stride_fwd_input_grad_token
+ current_offset * stride_fwd_input_grad_hidden + current_offset * stride_fwd_input_grad_hidden
...@@ -439,8 +427,6 @@ def _unpermute_bwd_with_merging_probs_kernel( ...@@ -439,8 +427,6 @@ def _unpermute_bwd_with_merging_probs_kernel(
dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden
) )
fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask) fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask)
if FP8_DTYPE is not None:
fwd_input = fwd_input.to(data_type, bitcast=True)
prob_grad_accum += fwd_input.to(compute_type) * inp prob_grad_accum += fwd_input.to(compute_type) * inp
current_start += BLOCK_SIZE current_start += BLOCK_SIZE
probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty) probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty)
...@@ -481,15 +467,8 @@ def unpermute_with_mask_map_bwd_with_merging_probs( ...@@ -481,15 +467,8 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
num_experts: int, num_experts: int,
num_out_tokens: int, num_out_tokens: int,
hidden_size: int, hidden_size: int,
fp8_dtype: TE_DType,
): ):
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
if fp8_dtype == TE_DType.kFloat8E5M2:
fp8_dtype = "e5m2"
elif fp8_dtype == TE_DType.kFloat8E4M3:
fp8_dtype = "e4m3"
else:
fp8_dtype = None
act_grad = torch.empty( act_grad = torch.empty(
(num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda" (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda"
) )
...@@ -517,7 +496,6 @@ def unpermute_with_mask_map_bwd_with_merging_probs( ...@@ -517,7 +496,6 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
merging_probs.stride(1), merging_probs.stride(1),
merging_probs_grad.stride(0), merging_probs_grad.stride(0),
merging_probs_grad.stride(1), merging_probs_grad.stride(1),
fp8_dtype,
) )
return act_grad, merging_probs_grad return act_grad, merging_probs_grad
......
...@@ -11,6 +11,7 @@ from typing import Any, Callable, List, Optional, Tuple ...@@ -11,6 +11,7 @@ from typing import Any, Callable, List, Optional, Tuple
import torch import torch
import transformer_engine.pytorch.cpp_extensions as ext import transformer_engine.pytorch.cpp_extensions as ext
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
from .tensor.quantized_tensor import QuantizedTensor from .tensor.quantized_tensor import QuantizedTensor
from torch.utils.cpp_extension import IS_HIP_EXTENSION from torch.utils.cpp_extension import IS_HIP_EXTENSION
...@@ -354,6 +355,19 @@ def round_up_to_nearest_multiple(value, multiple): ...@@ -354,6 +355,19 @@ def round_up_to_nearest_multiple(value, multiple):
return ((value + multiple - 1) // multiple) * multiple return ((value + multiple - 1) // multiple) * multiple
def needs_quantized_gemm(obj, rowwise=True):
"""Used to check if obj will need quantized gemm or normal gemm."""
if isinstance(obj, DebugQuantizedTensor):
return type(obj.get_tensor(not rowwise)) not in [ # pylint: disable=unidiomatic-typecheck
torch.Tensor,
torch.nn.Parameter,
]
return type(obj) not in [
torch.Tensor,
torch.nn.Parameter,
] # pylint: disable=unidiomatic-typecheck
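A hedged sketch of how the helper is meant to be used: plain torch.Tensor / Parameter operands take the regular GEMM path, while quantized tensor wrappers (or a DebugQuantizedTensor holding one) take the quantized path. quantized_gemm here is a hypothetical stand-in for the FP8 GEMM call:

if needs_quantized_gemm(weight):
    out = quantized_gemm(weight, inp)      # hypothetical FP8 GEMM entry point
else:
    out = torch.matmul(inp, weight.t())    # regular high-precision GEMM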
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def _nvtx_enabled() -> bool: def _nvtx_enabled() -> bool:
"""Check if NVTX range profiling is enabled""" """Check if NVTX range profiling is enabled"""
......