Commit ab3e5a92 authored by yuguo


Merge commit '04c730c0' of https://github.com/NVIDIA/TransformerEngine
parents a8d19fd9 04c730c0
......@@ -101,7 +101,7 @@ if __name__ == "__main__":
ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension},
install_requires=["jax", "flax>=0.7.1"],
tests_require=["numpy", "praxis"],
tests_require=["numpy"],
)
if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
shutil.rmtree(common_headers_dir)
......
......@@ -89,7 +89,11 @@ def generate_pspec(logical_axis_names):
Convert logical axes to PartitionSpec
"""
rules = get_sharding_map_logic_axis_to_mesh_axis()
mesh_axis_names = [rules[name] for name in logical_axis_names]
# mesh_axis_names = [rules[name] for name in logical_axis_names]
mesh_axis_names = []
for name in logical_axis_names:
axis_name = rules[name] if name in rules else None
mesh_axis_names.append(axis_name)
pspec = jax.sharding.PartitionSpec(*mesh_axis_names)
return pspec
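A minimal sketch of the new fallback behavior (the rules dict and axis names below are hypothetical; the real mapping comes from get_sharding_map_logic_axis_to_mesh_axis()): logical axis names without a registered mesh axis now resolve to None instead of raising a KeyError, so the corresponding dimension is simply left unsharded.

import jax

# Hypothetical rules map; real code obtains it from
# get_sharding_map_logic_axis_to_mesh_axis().
rules = {"nvte_batch": "dp", "nvte_hidden_tp": "tp"}
logical_axis_names = ("nvte_batch", "unregistered_axis", "nvte_hidden_tp")
mesh_axis_names = [rules[name] if name in rules else None for name in logical_axis_names]
# mesh_axis_names == ["dp", None, "tp"]; the unknown axis stays replicated.
pspec = jax.sharding.PartitionSpec(*mesh_axis_names)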
......@@ -112,7 +116,7 @@ def with_sharding_constraint_by_logical_axes(x: jnp.array, logical_axis_names: t
"""
A wrapper function to jax.lax.with_sharding_constraint to accept logical axes.
"""
if logical_axis_names is None:
if not logical_axis_names:
return x
assert len(x.shape) == len(logical_axis_names)
......@@ -315,3 +319,25 @@ class ShardingType(Enum):
TP_ROW = (MajorShardingType.TP, "tp_row")
DP_TP_COL = (MajorShardingType.DPTP, "dp_tp_col")
DP_TP_ROW = (MajorShardingType.DPTP, "dp_tp_row")
def get_non_contracting_logical_axes(ndim, logical_axes, contracting_dims):
"""Get logical axes for non-contracting dimensions.
Args:
ndim: Number of dimensions in the tensor.
logical_axes: Tuple of logical axes for each dimension.
contracting_dims: Set of dimensions that are being contracted.
Returns:
Tuple of logical axes for non-contracting dimensions.
"""
if not logical_axes:
logical_axes = (None,) * ndim
elif len(logical_axes) < ndim:
logical_axes = logical_axes + (None,) * (ndim - len(logical_axes))
assert len(logical_axes) == ndim
non_contracting_dims = [i for i in range(ndim) if i not in contracting_dims]
non_contracting_logical_axes = tuple(logical_axes[i] for i in non_contracting_dims)
return non_contracting_logical_axes
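A small worked example of the helper above (axis names are illustrative only): logical axes shorter than ndim are right-padded with None before the contracting dimensions are dropped.

# Illustrative values; axis names are hypothetical.
ndim = 3
logical_axes = ("nvte_batch", "nvte_seqlen")  # shorter than ndim, padded with None
contracting_dims = {2}
# Padded logical_axes: ("nvte_batch", "nvte_seqlen", None)
# Non-contracting dims: [0, 1]
# get_non_contracting_logical_axes(ndim, logical_axes, contracting_dims)
#     == ("nvte_batch", "nvte_seqlen")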
......@@ -20,6 +20,7 @@ import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine_torch as tex
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.utils import (
get_cudnn_version,
nvtx_range_pop,
......@@ -81,6 +82,7 @@ import transformer_engine.pytorch.dot_product_attention.utils as dpa_utils
from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils
from transformer_engine.pytorch.dot_product_attention.utils import AttentionLogging as attn_log
from transformer_engine.pytorch.dot_product_attention.rope import apply_rotary_pos_emb
from .cpu_offload import mark_activation_offload
# Setup Attention Logging
......@@ -618,7 +620,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
rank = get_distributed_rank(cp_group)
send_dst = cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
recv_src = cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0"))
causal = "causal" in attn_mask_type
padding = "padding" in attn_mask_type
......@@ -1566,7 +1568,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
rank = get_distributed_rank(ctx.cp_group)
send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2)
batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0"))
q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = (
restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
......@@ -4323,10 +4325,9 @@ class FlashAttention(torch.nn.Module):
from .cpu_offload import CPUOffloadEnabled
if CPUOffloadEnabled:
tensor_list = [query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv]
for tensor in tensor_list:
if tensor is not None:
tensor.activation_offloading = True
mark_activation_offload(
query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv
)
with self.attention_dropout_ctx():
# | API | use cases
......@@ -4728,12 +4729,9 @@ class FusedAttnFunc(torch.autograd.Function):
else:
tensor_list = [q, k, v, out_save]
tensor_list.extend(aux_ctx_tensors)
qkv_layout = "sbhd_sbhd_sbhd"
for tensor in tensor_list:
if tensor is not None:
tensor.activation_offloading = True
mark_activation_offload(*tensor_list)
mark_activation_offload(*aux_ctx_tensors)
ctx.is_input_fp8 = is_input_fp8
ctx.is_output_fp8 = is_output_fp8
......@@ -6482,6 +6480,8 @@ class MultiheadAttention(torch.nn.Module):
equal length. Please note that these formats do not reflect how
tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory.
For that, please use `get_qkv_layout` to gain the layout information.
name: str, default = `None`
name of the module, currently used for debugging purposes.
Parallelism parameters
----------------------
......@@ -6560,6 +6560,7 @@ class MultiheadAttention(torch.nn.Module):
normalization: str = "LayerNorm",
device: Union[torch.device, str] = "cuda",
qkv_format: str = "sbhd",
name: str = None,
) -> None:
super().__init__()
......@@ -6611,6 +6612,8 @@ class MultiheadAttention(torch.nn.Module):
self.hidden_size_q = self.hidden_size_per_attention_head * num_attention_heads
self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups
self.name = name
common_gemm_kwargs = {
"fuse_wgrad_accumulation": fuse_wgrad_accumulation,
"tp_group": tp_group,
......@@ -6651,6 +6654,7 @@ class MultiheadAttention(torch.nn.Module):
ub_overlap_ag=ub_overlap_ag,
normalization=normalization,
ub_name="qkv",
name=name + ".layernorm_linear_qkv" if name is not None else None,
**common_gemm_kwargs,
)
else:
......@@ -6662,6 +6666,7 @@ class MultiheadAttention(torch.nn.Module):
return_bias=False,
parallel_mode=qkv_parallel_mode,
parameters_split=parameters_split,
name=name + ".linear_qkv" if name is not None else None,
**common_gemm_kwargs,
)
elif self.attention_type == "cross":
......@@ -6683,6 +6688,7 @@ class MultiheadAttention(torch.nn.Module):
ub_overlap_ag=ub_overlap_ag,
normalization=normalization,
ub_name="qkv",
name=name + ".layernorm_linear_q" if name is not None else None,
**common_gemm_kwargs,
)
else:
......@@ -6693,6 +6699,7 @@ class MultiheadAttention(torch.nn.Module):
bias=bias,
return_bias=False,
parallel_mode=qkv_parallel_mode,
name=name + ".linear_q" if name is not None else None,
**common_gemm_kwargs,
)
self.key_value = Linear(
......@@ -6703,6 +6710,7 @@ class MultiheadAttention(torch.nn.Module):
return_bias=False,
parallel_mode=qkv_parallel_mode,
parameters_split=("key", "value") if not fuse_qkv_params else None,
name=name + ".linear_kv" if name is not None else None,
**common_gemm_kwargs,
)
......@@ -6732,6 +6740,7 @@ class MultiheadAttention(torch.nn.Module):
ub_overlap_rs=ub_overlap_rs,
ub_overlap_ag=ub_overlap_ag,
ub_name="proj",
name=name + ".proj" if name is not None else None,
**common_gemm_kwargs,
)
......@@ -6922,6 +6931,9 @@ class MultiheadAttention(torch.nn.Module):
core_attention_bias_type in AttnBiasTypes
), f"core_attention_bias_type {core_attention_bias_type} is not supported!"
if TEDebugState.debug_enabled:
TransformerEngineBaseModule._validate_name(self)
# =================================================
# Pre-allocate memory for key-value cache for inference
# =================================================
......
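A hedged usage sketch of the new optional name argument (values other than name are arbitrary; only arguments visible in this diff plus the standard hidden_size/num_attention_heads constructor arguments are shown): the name is forwarded to the internal linear submodules as "<name>.linear_qkv", "<name>.proj", etc., and is validated via TransformerEngineBaseModule._validate_name when debug mode is enabled.

import transformer_engine.pytorch as te

# Illustrative sketch only.
mha = te.MultiheadAttention(
    hidden_size=1024,
    num_attention_heads=16,
    name="decoder.layer_0.self_attention",  # new: used for debug identification
)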
......@@ -24,6 +24,12 @@ TE_DType = {
torch.bfloat16: tex.DType.kBFloat16,
}
"""
This is a map: int -> torch.dtype.
It is used for resolving CUDA extension dtypes to torch dtypes and has a
one-to-one mapping with the DType enum in transformer_engine.h.
"""
TE_DType_To_Torch = {
tex.DType.kByte: torch.uint8,
tex.DType.kFloat8E4M3: torch.float8_e4m3fn,
......
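A short sketch of how the two maps relate (assuming both dicts live in transformer_engine.pytorch.constants, as this diff suggests): TE_DType translates torch dtypes to transformer_engine_torch enum values, and TE_DType_To_Torch is its inverse.

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType, TE_DType_To_Torch

# Entries shown in this diff: kByte maps back to torch.uint8, and the
# torch.bfloat16 entry of TE_DType round-trips through the inverse map.
assert TE_DType_To_Torch[tex.DType.kByte] is torch.uint8
assert TE_DType_To_Torch[TE_DType[torch.bfloat16]] is torch.bfloat16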
......@@ -9,11 +9,11 @@ import os
import torch
import transformer_engine_torch as tex
from ..constants import TE_DType
from ..utils import assert_dim_for_fp8_exec, get_sm_count
from ..utils import get_sm_count
from ..tensor.quantized_tensor import Quantizer
from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ...debug.pytorch.debug_quantization import DebugQuantizer
__all__ = [
"general_gemm",
......@@ -28,46 +28,6 @@ def _empty_tensor() -> torch.Tensor:
return torch.Tensor().cuda()
def swizzle_inputs(A: torch.Tensor, B: torch.Tensor, layout: str):
"""Swizzle gemm inputs and return original scaling factor inverses."""
if not isinstance(A, MXFP8TensorBase) or not isinstance(B, MXFP8TensorBase):
return None
original_scale_inverses = (
A._rowwise_scale_inv,
A._columnwise_scale_inv,
B._rowwise_scale_inv,
B._columnwise_scale_inv,
)
if layout[0] == "T":
A._rowwise_scale_inv = tex.rowwise_swizzle(A._rowwise_data, A._rowwise_scale_inv)
else:
A._columnwise_scale_inv = tex.columnwise_swizzle(
A._columnwise_data, A._columnwise_scale_inv
)
if layout[1] == "N":
B._rowwise_scale_inv = tex.rowwise_swizzle(B._rowwise_data, B._rowwise_scale_inv)
else:
B._columnwise_scale_inv = tex.columnwise_swizzle(
B._columnwise_data, B._columnwise_scale_inv
)
return original_scale_inverses
def reset_swizzled_inputs(A, B, scale_inverses):
"""Reset the swizzled scale inverses after GEMM."""
if scale_inverses is not None:
(
A._rowwise_scale_inv,
A._columnwise_scale_inv,
B._rowwise_scale_inv,
B._columnwise_scale_inv,
) = scale_inverses
def general_gemm(
A: torch.Tensor,
B: torch.Tensor,
......@@ -110,9 +70,20 @@ def general_gemm(
if not out.is_contiguous():
raise ValueError("Output tensor is not contiguous.")
debug_quantizer = None
if isinstance(quantization_params, DebugQuantizer):
debug_quantizer = quantization_params
quantization_params = quantization_params.parent_quantizer
A = A.get_tensor(not transa)
B = B.get_tensor(transb)
# Use bfloat16 as default bias_dtype
bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype]
if isinstance(A, Float8BlockwiseQTensorBase) or isinstance(B, Float8BlockwiseQTensorBase):
# There is no use_split_accumulator == False implementation
# for Float8BlockwiseQTensorBase GEMM, so force it on.
use_split_accumulator = True
args = (
A,
transa, # transa
......@@ -138,9 +109,10 @@ def general_gemm(
"bulk_overlap": bulk_overlap,
}
original_scale_inverses = swizzle_inputs(A, B, layout)
out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs)
reset_swizzled_inputs(A, B, original_scale_inverses)
if debug_quantizer is not None:
out = debug_quantizer.process_gemm_output(out)
return out, bias_grad, gelu_input, extra_output
......@@ -170,14 +142,6 @@ def general_grouped_gemm(
transa = layout[0] == "T"
transb = layout[1] == "T"
# assert [a.is_contiguous() for a in A]
# assert [b.is_contiguous() for b in B]
if isinstance(A[0], Float8TensorBase):
for a, b in zip(A, B):
assert_dim_for_fp8_exec(a._data)
assert_dim_for_fp8_exec(b._data)
empty_tensor = _empty_tensor()
empty_tensors = [empty_tensor] * num_gemms
......
......@@ -16,18 +16,22 @@ __all__ = ["get_cpu_offload_context"]
CPUOffloadEnabled = False
def set_offloading_param(tensor, param_name, value):
def mark_activation_offload(*tensors):
"""Set the type of the offloading needed for a tensor."""
assert param_name in ["weight_offloading", "activation_offloading"]
if tensor is None:
return
if type(tensor) in [torch.Tensor, torch.nn.Parameter]:
setattr(tensor, param_name, value)
else:
data_tensors = tensor.get_data_tensors()
for tensor in data_tensors:
if tensor is not None:
setattr(tensor, param_name, value)
for tensor in tensors:
if tensor is None:
continue
if type(tensor) in [torch.Tensor, torch.nn.Parameter]:
tensor.activation_offloading = True
else:
data_tensors = tensor.get_data_tensors()
for tensor in data_tensors:
if tensor is not None:
tensor.activation_offloading = True
# This is a hack to force clear the tensor after it is offloaded.
# It is needed because .*TensorBase classes are saved in the ctx
# and they hold references to their data tensors.
tensor.needs_force_clear = True
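A minimal sketch of how the helper above is used (plain CUDA tensors here; quantized *TensorBase objects would instead have their internal data tensors marked and additionally flagged with needs_force_clear):

import torch
from transformer_engine.pytorch.cpu_offload import mark_activation_offload

q = torch.randn(4, 8, device="cuda")
k = torch.randn(4, 8, device="cuda")
mark_activation_offload(q, k, None)  # None entries are skipped
assert getattr(q, "activation_offloading", False)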
def is_cpu_offload_enabled() -> bool:
......@@ -459,8 +463,15 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
torch.cuda.current_stream().wait_stream(self.d2h_stream)
# Time to free the activation memory after usage
for tensor_tag, _ in self.tensor_tag_to_buf.items():
for tensor_tag, tensor_buf in self.tensor_tag_to_buf.items():
if tensor_tag[0] == self.offloaded_group_count:
if hasattr(tensor_buf, "needs_force_clear"):
# Need to clear activation tensor - sometimes references persist in the code.
# This is the case for example with the Float8TensorBase class,
# which is saved directly inside the ctx while its internal tensors are
# saved inside save_for_backward.
tensor_buf.data = torch.Tensor()
# Release the pointer to the tensor
self.tensor_tag_to_buf[tensor_tag] = None
# Time to offload the next group
......@@ -538,7 +549,7 @@ def get_cpu_offload_context(
num_layers: int = 1,
model_layers: int = 1,
offload_activations: bool = True,
offload_weights: bool = True,
offload_weights: bool = False,
):
"""
This function returns the CPU Offload context and the synchronizer function that needs to be
......@@ -570,28 +581,30 @@ def get_cpu_offload_context(
"""
def tensor_need_offloading_checker_activations(tensor):
return hasattr(tensor, "activation_offloading")
# This includes the Gradient Accumulation Buffer
def tensor_need_offloading_checker_weights(tensor):
return hasattr(tensor, "weight_offloading")
def tensor_need_offloading_checker_all(tensor):
return hasattr(tensor, "activation_offloading") or hasattr(tensor, "weight_offloading")
if offload_activations and offload_weights:
tensor_need_offloading_checker = tensor_need_offloading_checker_all
elif offload_activations:
tensor_need_offloading_checker = tensor_need_offloading_checker_activations
elif offload_weights:
tensor_need_offloading_checker = tensor_need_offloading_checker_weights
else:
if not offload_weights and not offload_activations:
raise ValueError(
"CPU Offloading is enabled while it is not "
"mentioned what to offload (weights/activations)"
)
if offload_weights:
import warnings
warnings.warn(
"Offloading weights is deprecated. Using offload_weights=True does not have any"
" effect.",
DeprecationWarning,
)
# Weights offloading is deprecated but we maintain backward compatibility by doing nothing.
if not offload_activations:
return nullcontext(), lambda x: x
def tensor_need_offloading_checker_activations(tensor):
return hasattr(tensor, "activation_offloading")
tensor_need_offloading_checker = tensor_need_offloading_checker_activations
cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler(
num_offload_group=num_layers,
num_model_group=model_layers,
......
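A usage sketch reflecting the behavior above (only parameters visible in this diff are shown; the full signature may include additional arguments such as an enable flag): offload_weights is now a deprecated no-op, and disabling activation offloading returns a null context.

from transformer_engine.pytorch import get_cpu_offload_context

# Illustrative values.
cpu_offload_context, sync_function = get_cpu_offload_context(
    num_layers=3,
    model_layers=12,
    offload_activations=True,
    offload_weights=False,  # deprecated: passing True only emits a DeprecationWarning
)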
......@@ -167,6 +167,38 @@ class Float8CurrentScalingQuantizer : public Quantizer {
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
};
class Float8BlockQuantizer : public Quantizer {
public:
// Which float8 type is used for the quantized data.
DType dtype;
// Options about how to quantize the tensor
// Quantization scales are rounded down to powers of 2.
bool force_pow_2_scales = false;
// Amax within quantization tile has a floor of epsilon.
float amax_epsilon = 0.0;
private:
int block_scaling_dim = 2;
public:
// Initializes from a python handle to a Float8BlockQuantizer
explicit Float8BlockQuantizer(const py::handle& quantizer);
NVTEScalingMode get_scaling_mode() const override {
return (block_scaling_dim == 2) ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D;
}
// Gets rowwise and columnwise_data from tensor and sets them on wrapper
void set_quantization_params(TensorWrapper* tensor) const override;
// Create a python Float8BlockQuantized tensor and C++ wrapper
// for the tensor. Should set quantized data, scales for rowwise
// and optionally columnwise usage.
std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
};
class MXFP8Quantizer : public Quantizer {
public:
DType dtype;
......
......@@ -50,11 +50,11 @@ std::vector<py::object> fused_attn_fwd(
NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> page_table_k, const c10::optional<at::Tensor> page_table_v,
py::handle s_quantizer, py::handle o_quantizer, const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);
const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded,
const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
const std::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);
std::vector<py::object> fused_attn_bwd(
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
......@@ -63,8 +63,8 @@ std::vector<py::object> fused_attn_bwd(
const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
const py::handle O, const py::handle dO, const at::ScalarType fake_dtype,
const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
py::handle dp_quantizer, py::handle dqkv_quantizer);
at::Tensor fa_prepare_fwd(at::Tensor qkvi);
......@@ -121,18 +121,22 @@ std::vector<at::Tensor> te_batchgemm_ts(
int64_t workspaceSize, int64_t accumulate, int64_t use_split_accumulator);
#endif
namespace transformer_engine::pytorch {
/***************************************************************************************************
* Transpose
**************************************************************************************************/
std::vector<py::object> fused_multi_quantize(std::vector<py::handle> input_list,
std::optional<std::vector<py::handle>> output_list,
std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
std::optional<std::vector<py::object>> output_list,
std::vector<py::handle> quantizer_list,
transformer_engine::DType otype);
at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype,
std::optional<at::Tensor> output = std::nullopt);
} // namespace transformer_engine::pytorch
namespace transformer_engine::pytorch {
/***************************************************************************************************
......@@ -285,16 +289,14 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reductio
**************************************************************************************************/
at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
const bool transpose_output_memory);
const NVTE_QKV_Format qkv_format, const bool interleaved,
const std::optional<at::Tensor> cu_seqlens, const int cp_size,
const int cp_rank);
at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
const bool transpose_output_memory);
at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens,
const at::Tensor &freqs, const int cp_size, const int cp_rank);
at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens,
const at::Tensor &freqs, const int cp_size, const int cp_rank);
const NVTE_QKV_Format qkv_format, const bool interleaved,
const std::optional<at::Tensor> cu_seqlens, const int cp_size,
const int cp_rank);
/***************************************************************************************************
* Miscellaneous
......@@ -394,10 +396,25 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
std::vector<size_t> padded_input_row_list);
/***************************************************************************************************
* swizzle
* NVSHMEM APIs
**************************************************************************************************/
void swizzle_scaling_factors(transformer_engine::TensorWrapper &input, bool trans);
namespace nvshmem_api {
void init_nvshmem_backend(c10d::ProcessGroup *process_group);
torch::Tensor create_nvshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype);
void nvshmem_send_on_current_stream(torch::Tensor src, torch::Tensor dst, int peer,
torch::Tensor signal);
void nvshmem_wait_on_current_stream(torch::Tensor signal, const std::string &wait_kind);
void nvshmem_finalize();
} // namespace nvshmem_api
/***************************************************************************************************
* swizzle
**************************************************************************************************/
at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv);
......
......@@ -50,7 +50,12 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int
nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream());
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape);
nvte_quantize(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
nvte_quantize_v2(te_output_act.data(), te_output.data(), quant_config,
at::cuda::getCurrentCUDAStream());
} else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
// sanity check, since activation fusion is not supported for blockwise quantization yet
// need to raise an error here instead of silently going into act_func with wrong numerics
NVTE_ERROR("Activation fusion is not supported for blockwise quantization yet.");
} else {
act_func(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
}
......
......@@ -7,217 +7,181 @@
#include "extensions.h"
at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
const bool transpose_output_memory) {
const NVTE_QKV_Format qkv_format, const bool interleaved,
const std::optional<at::Tensor> cu_seqlens, const int cp_size,
const int cp_rank) {
using namespace transformer_engine::pytorch;
TORCH_CHECK(input.dim() == 4, "expected 4D tensor");
TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
TORCH_CHECK(input.size(0) <= freqs.size(0),
"expected freqs tensor has a longer sequence length than input");
TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
"expected the second and third dims of the freqs tensor equal 1");
TORCH_CHECK(input.size(3) >= freqs.size(3),
"expected the last dim of the input tensor equals or is "
"greater than the freqs tensor");
TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
"Dtype of the freqs tensor must be float");
// input sizes: (s, b, h, d)
// output
auto act_options = at::TensorOptions().dtype(input.scalar_type()).device(input.device());
auto output = at::empty(input.sizes(), act_options);
auto input_cu = makeTransformerEngineTensor(input);
auto freqs_cu = makeTransformerEngineTensor(freqs);
auto output_cu = makeTransformerEngineTensor(output);
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
TORCH_CHECK(input.dim() == 3, "expected 3D tensor");
TORCH_CHECK(cu_seqlens.has_value(), "expected cu_seqlens tensor");
TORCH_CHECK(cu_seqlens.value().dim() == 1, "expected 1D tensor");
TORCH_CHECK(input.size(2) >= freqs.size(3),
"expected the last dim of the input tensor equals or is "
"greater than the freqs tensor");
// input sizes: (t, h, d)
// t: cumulative sum of sequence lengths
// h: head num
// d: dim of each head
// const int t = input.size(0);
const int h = input.size(1);
const int d = input.size(2);
// input strides
const int stride_t = input.stride(0);
const int stride_h = input.stride(1);
const int stride_d = input.stride(2);
// batch size
const int b = cu_seqlens.value().size(0) - 1;
// freqs' shape is (max_s, 1, 1, d2)
const int max_s = freqs.size(0);
const int d2 = freqs.size(3);
auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value());
nvte_fused_rope_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
output_cu.data(), qkv_format, interleaved, cp_size, cp_rank, max_s, b,
h, d, d2, stride_t, /*stride_b=*/0, stride_h, stride_d,
at::cuda::getCurrentCUDAStream());
return output;
}
TORCH_CHECK(input.dim() == 4, "expected 4D tensor");
// input sizes: (s, b, h, d) or (b, s, h, d)
// s: sequence length
// b: batch size
// h: head num
// d: dim of each head
const int s = input.size(0);
const int b = input.size(1);
const int s = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.size(0) : input.size(1);
const int b = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.size(1) : input.size(0);
const int h = input.size(2);
const int d = input.size(3);
// input strides
const int stride_s = input.stride(0);
const int stride_b = input.stride(1);
const int stride_s = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.stride(0) : input.stride(1);
const int stride_b = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.stride(1) : input.stride(0);
const int stride_h = input.stride(2);
const int stride_d = input.stride(3);
// freqs' shape is always (s, 1, 1, d2), so the strides are same under
// different memory formats
const int d2 = freqs.size(3);
// output
auto act_options = input.options().requires_grad(false);
at::Tensor output;
if (transpose_output_memory) {
output = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
} else {
output = torch::empty({s, b, h, d}, act_options);
}
// output strides
const int o_stride_s = output.stride(0);
const int o_stride_b = output.stride(1);
const int o_stride_h = output.stride(2);
const int o_stride_d = output.stride(3);
auto input_cu = makeTransformerEngineTensor(input);
auto freqs_cu = makeTransformerEngineTensor(freqs);
auto output_cu = makeTransformerEngineTensor(output);
TORCH_CHECK(s * cp_size <= freqs.size(0),
"expected freqs tensor has a longer sequence length than input");
TORCH_CHECK(d >= d2,
"expected the last dim of the input tensor equals or is "
"greater than the freqs tensor");
nvte_fused_rope_forward(input_cu.data(), freqs_cu.data(), output_cu.data(), s, b, h, d, d2,
stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream());
auto cu_seqlens_cu = transformer_engine::TensorWrapper(); // empty cu_seqlens tensor
nvte_fused_rope_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), output_cu.data(),
qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s,
stride_b, stride_h, stride_d, at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
const bool transpose_output_memory) {
const NVTE_QKV_Format qkv_format, const bool interleaved,
const std::optional<at::Tensor> cu_seqlens, const int cp_size,
const int cp_rank) {
using namespace transformer_engine::pytorch;
TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor");
TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
TORCH_CHECK(output_grads.size(0) <= freqs.size(0),
"expected freqs tensor has a longer sequence length than output_grads");
TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
"expected the second and third dims of the freqs tensor equal 1");
TORCH_CHECK(output_grads.size(3) >= freqs.size(3),
"expected the last dim of the output_grads tensor equals or is "
"greater than the freqs tensor");
TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
"Dtype of the freqs tensor must be float");
auto act_options =
at::TensorOptions().dtype(output_grads.scalar_type()).device(output_grads.device());
auto input_grads = at::empty(output_grads.sizes(), act_options);
auto output_grads_cu = makeTransformerEngineTensor(output_grads);
auto freqs_cu = makeTransformerEngineTensor(freqs);
auto input_grads_cu = makeTransformerEngineTensor(input_grads);
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor");
TORCH_CHECK(cu_seqlens.has_value(), "expected cu_seqlens tensor");
TORCH_CHECK(cu_seqlens.value().dim() == 1, "expected 1D tensor");
TORCH_CHECK(output_grads.size(2) >= freqs.size(3),
"expected the last dim of the output_grads tensor equals or is "
"greater than the freqs tensor");
// output_grads sizes: (t, h, d)
// t: cumulative sum of sequence lengths
// h: head num
// d: dim of each head
// const int t = output_grads.size(0);
const int h = output_grads.size(1);
const int d = output_grads.size(2);
// output_grads strides
const int stride_t = output_grads.stride(0);
const int stride_h = output_grads.stride(1);
const int stride_d = output_grads.stride(2);
// batch size
const int b = cu_seqlens.value().size(0) - 1;
// freqs' shape is (max_s, 1, 1, d2)
const int max_s = freqs.size(0);
const int d2 = freqs.size(3);
auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value());
nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank,
max_s, b, h, d, d2, stride_t, /*stride_b=*/0, stride_h, stride_d,
at::cuda::getCurrentCUDAStream());
return input_grads;
}
TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor");
// output_grads sizes: (s, b, h, d)
// s: sequence length
// b: batch size
// h: head num
// d: dim of each head
const int s = output_grads.size(0);
const int b = output_grads.size(1);
const int s =
qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.size(0) : output_grads.size(1);
const int b =
qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.size(1) : output_grads.size(0);
const int h = output_grads.size(2);
const int d = output_grads.size(3);
// output_grads strides
const int stride_s = output_grads.stride(0);
const int stride_b = output_grads.stride(1);
const int stride_s =
qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.stride(0) : output_grads.stride(1);
const int stride_b =
qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.stride(1) : output_grads.stride(0);
const int stride_h = output_grads.stride(2);
const int stride_d = output_grads.stride(3);
// freqs' shape is always (s, 1, 1, d2), so the strides are same under
// different memory formats
const int d2 = freqs.size(3);
auto act_options = output_grads.options().requires_grad(false);
at::Tensor input_grads;
if (transpose_output_memory) {
input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
} else {
input_grads = torch::empty({s, b, h, d}, act_options);
}
const int o_stride_s = input_grads.stride(0);
const int o_stride_b = input_grads.stride(1);
const int o_stride_h = input_grads.stride(2);
const int o_stride_d = input_grads.stride(3);
auto output_grads_cu = makeTransformerEngineTensor(output_grads);
auto freqs_cu = makeTransformerEngineTensor(freqs);
auto input_grads_cu = makeTransformerEngineTensor(input_grads);
nvte_fused_rope_backward(output_grads_cu.data(), freqs_cu.data(), input_grads_cu.data(), s, b, h,
d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream());
return input_grads;
}
at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens,
const at::Tensor &freqs, const int cp_size, const int cp_rank) {
using namespace transformer_engine::pytorch;
TORCH_CHECK(input.dim() == 3, "expected 3D tensor");
TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor");
TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
"expected the second and third dims of the freqs tensor equal 1");
TORCH_CHECK(input.size(2) >= freqs.size(3),
"expected the last dim of the input tensor equals or is "
"greater than the freqs tensor");
TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
"Dtype of the freqs tensor must be float");
// input sizes: (t, h, d)
// t: cumulative sum of sequence lengths
// h: head num
// d: dim of each head
const int t = input.size(0);
const int h = input.size(1);
const int d = input.size(2);
// input strides
const int stride_t = input.stride(0);
const int stride_h = input.stride(1);
const int stride_d = input.stride(2);
// batch size
const int b = cu_seqlens.size(0) - 1;
// freqs' shape is (max_s, 1, 1, d2)
const int max_s = freqs.size(0);
const int d2 = freqs.size(3);
// output
auto act_options = input.options().requires_grad(false);
auto output = torch::empty({t, h, d}, act_options);
// output strides
const int o_stride_t = output.stride(0);
const int o_stride_h = output.stride(1);
const int o_stride_d = output.stride(2);
auto input_cu = makeTransformerEngineTensor(input);
auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens);
auto freqs_cu = makeTransformerEngineTensor(freqs);
auto output_cu = makeTransformerEngineTensor(output);
nvte_fused_rope_thd_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
output_cu.data(), cp_size, cp_rank, max_s, b, h, d, d2, stride_t,
stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d,
at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens,
const at::Tensor &freqs, const int cp_size, const int cp_rank) {
using namespace transformer_engine::pytorch;
TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor");
TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor");
TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
"expected the second and third dims of the freqs tensor equal 1");
TORCH_CHECK(output_grads.size(2) >= freqs.size(3),
TORCH_CHECK(s * cp_size <= freqs.size(0),
"expected freqs tensor has a longer sequence length than output_grads");
TORCH_CHECK(d >= d2,
"expected the last dim of the output_grads tensor equals or is "
"greater than the freqs tensor");
TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
"Dtype of the freqs tensor must be float");
// output_grads sizes: (t, h, d)
// t: cumulative sum of sequence lengths
// h: head num
// d: dim of each head
const int t = output_grads.size(0);
const int h = output_grads.size(1);
const int d = output_grads.size(2);
// output_grads strides
const int stride_t = output_grads.stride(0);
const int stride_h = output_grads.stride(1);
const int stride_d = output_grads.stride(2);
// batch size
const int b = cu_seqlens.size(0) - 1;
// freqs' shape is (max_s, 1, 1, d2)
const int max_s = freqs.size(0);
const int d2 = freqs.size(3);
auto act_options = output_grads.options().requires_grad(false);
auto input_grads = torch::empty({t, h, d}, act_options);
const int o_stride_t = input_grads.stride(0);
const int o_stride_h = input_grads.stride(1);
const int o_stride_d = input_grads.stride(2);
auto output_grads_cu = makeTransformerEngineTensor(output_grads);
auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens);
auto freqs_cu = makeTransformerEngineTensor(freqs);
auto input_grads_cu = makeTransformerEngineTensor(input_grads);
nvte_fused_rope_thd_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
input_grads_cu.data(), cp_size, cp_rank, max_s, b, h, d, d2,
stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d,
at::cuda::getCurrentCUDAStream());
auto cu_seqlens_cu = transformer_engine::TensorWrapper(); // empty cu_seqlens tensor
nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank, s, b,
h, d, d2, stride_s, stride_b, stride_h, stride_d,
at::cuda::getCurrentCUDAStream());
return input_grads;
}
......@@ -3,9 +3,11 @@
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "kv_cache.cuh"
#include "thd_utils.cuh"
#include "transformer_engine/transformer_engine.h"
constexpr int block_size = 512;
constexpr int ctas_per_sm = 4;
......@@ -95,11 +97,11 @@ std::vector<py::object> fused_attn_fwd(
NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> page_table_k, const c10::optional<at::Tensor> page_table_v,
py::handle s_quantizer, py::handle o_quantizer, const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) {
const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded,
const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
const std::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) {
#ifdef __HIP_PLATFORM_AMD__
assert(false);
#else
......@@ -289,8 +291,8 @@ std::vector<py::object> fused_attn_bwd(
const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
const py::handle O, const py::handle dO, const at::ScalarType fake_dtype,
const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
py::handle dp_quantizer, py::handle dqkv_quantizer) {
#ifdef __HIP_PLATFORM_AMD__
assert(false);
......@@ -461,13 +463,13 @@ std::vector<py::object> fused_attn_bwd(
nvte_tensor_pack_create(&nvte_aux_tensor_pack);
nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size();
for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) {
std::vector<int64_t> tmp(Aux_CTX_Tensors[i].sizes().vec());
auto temp_vec = std::vector<size_t>(tmp.begin(), tmp.end());
const NVTEShape temp_shape = {temp_vec.data(), temp_vec.size()};
const std::vector<int64_t> &signed_shape = Aux_CTX_Tensors[i].sizes().vec();
const std::vector<size_t> tmp(signed_shape.begin(), signed_shape.end());
NVTEBasicTensor temp_data = {
Aux_CTX_Tensors[i].data_ptr(),
static_cast<NVTEDType>(GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type())),
temp_shape};
nvte_make_shape(tmp.data(), tmp.size())};
nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data);
}
......
......@@ -46,6 +46,9 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
if (te_output.numel() == 0) return out;
QuantizationConfigWrapper quant_config;
quant_config.set_noop_tensor(te_noop.data());
if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(my_quantizer.get());
......@@ -61,15 +64,21 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
}
QuantizationConfigWrapper quant_config;
// This config is used for the current-scaling scale factor computation;
// the scale computation cannot be fused with the quantize kernel,
// so in nvte_quantize_v2 with current scaling the quant config is not used again.
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream());
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape);
} else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
auto my_quantizer_bw = static_cast<Float8BlockQuantizer*>(my_quantizer.get());
quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
}
nvte_quantize_noop(te_input.data(), te_output.data(), te_noop.data(),
at::cuda::getCurrentCUDAStream());
nvte_quantize_v2(te_input.data(), te_output.data(), quant_config,
at::cuda::getCurrentCUDAStream());
return out;
}
......
......@@ -157,15 +157,15 @@ void CommOverlap::copy_into_buffer(py::handle input, py::handle quantizer, bool
char *ubuf_ptr = reinterpret_cast<char *>(_ubuf.dptr());
if (local_chunk) {
if (input_tensor.numel() * _tp_size > (int64_t)_ubuf.numel())
if (input_tensor.numel() * _tp_size > _ubuf.numel())
NVTE_ERROR("input is larger than the local communication buffer!");
if (input_tensor.element_size() != (int64_t)_ubuf.element_size())
if (input_tensor.element_size() != _ubuf.element_size())
NVTE_ERROR("input data type does not match communication buffer!");
ubuf_ptr += (_ubuf.numel() / _tp_size) * _tp_id * _ubuf.element_size();
} else {
if (input_tensor.numel() > (int64_t)_ubuf.numel())
if (input_tensor.numel() > _ubuf.numel())
NVTE_ERROR("input is larger than the global communication buffer!");
if (input_tensor.element_size() != (int64_t)_ubuf.element_size())
if (input_tensor.element_size() != _ubuf.element_size())
NVTE_ERROR("input data type does not match communication buffer!");
}
......@@ -189,7 +189,7 @@ py::object CommOverlap::get_buffer(py::handle quantizer, bool local_chunk,
std::vector<int64_t> torch_shape;
if (shape.has_value()) {
torch_shape = shape.value();
auto requested = product(torch_shape);
size_t requested = product(torch_shape);
auto expected = local_chunk ? _ubuf.numel() / _tp_size : _ubuf.numel();
NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested,
") does not match allocated buffer size (", expected, ")!");
......@@ -253,18 +253,18 @@ void CommOverlapP2P::copy_into_buffer(py::handle input, py::handle quantizer, bo
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
if (local_chunk) {
// Copy input to the target ubuf chunk by rank offset
if (input_tensor.numel() * _tp_size > (int64_t)_ubuf.numel())
if (input_tensor.numel() * _tp_size > _ubuf.numel())
NVTE_ERROR("input is larger than the local communication buffer!");
if (input_tensor.element_size() != (int64_t)_ubuf.element_size())
if (input_tensor.element_size() != _ubuf.element_size())
NVTE_ERROR("input data type does not match communication buffer!");
NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].dptr(), input_ptr,
input_tensor.numel() * input_tensor.element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main));
} else {
if (input_tensor.numel() > (int64_t)_ubuf.numel())
if (input_tensor.numel() > _ubuf.numel())
NVTE_ERROR("input is larger than the global communication buffer!");
if (input_tensor.element_size() != (int64_t)_ubuf.element_size())
if (input_tensor.element_size() != _ubuf.element_size())
NVTE_ERROR("input data type does not match communication buffer!");
NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.dptr(), input_ptr,
input_tensor.numel() * input_tensor.element_size(),
......@@ -280,7 +280,7 @@ py::object CommOverlapP2P::get_buffer(py::handle quantizer, bool local_chunk,
std::vector<int64_t> torch_shape;
if (shape.has_value()) {
torch_shape = shape.value();
auto requested = product(torch_shape);
size_t requested = product(torch_shape);
auto expected = local_chunk ? _ubufs[_tp_id].numel() : _ubuf.numel();
NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested,
") does not match allocated buffer size (", expected, ")!");
......
......@@ -21,6 +21,7 @@
#include "extensions.h"
#include "pybind.h"
#include "transformer_engine/transformer_engine.h"
#include "util.h"
namespace {
......@@ -179,8 +180,15 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
const int sm_count = transformer_engine::cuda::sm_count(device_id);
int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);
// Keep the swizzled scaling factor tensors alive during the GEMM.
std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;
auto main_stream = at::cuda::getCurrentCUDAStream();
if (A_tensor.numel() != 0 && B_tensor.numel() != 0) {
// Optionally swizzle the scaling factors
swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(A_tensor, transa)));
swizzled_scale_inverses_list.emplace_back(
std::move(swizzle_scaling_factors(B_tensor, !transb)));
if (comm_overlap) {
// Prepare extra output tensor
TensorWrapper extra_output_tensor;
......@@ -317,17 +325,18 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
te_pre_gelu_out_vector, te_workspace_vector;
std::vector<TensorWrapper> wrappers;
std::vector<at::Tensor> D_vectors;
// Keep the swizzled scaling factor tensors alive during the GEMMs.
std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;
auto none = py::none();
std::vector<size_t> single_output_begins;
std::vector<size_t> single_output_ends;
int slicing_dim;
if (single_output && D == std::nullopt) {
NVTE_ERROR("not implemented, D should be allocated for single output case.");
}
void* output_data_ptr;
void* output_data_ptr = nullptr;
if (single_output) {
output_data_ptr = (*D)[0].data_ptr();
}
......@@ -384,6 +393,10 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
continue;
}
// Optionally swizzle the scaling factors
swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_A, transa)));
swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_B, !transb)));
auto te_D = makeTransformerEngineTensor(out_tensor);
auto te_bias = makeTransformerEngineTensor(bias[i]);
auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]);
......
......@@ -12,6 +12,8 @@
// #include <torch/all.h>
#include <assert.h>
#include <limits>
// Stringstream is a big hammer, but I want to rely on operator<< for dtype.
#include <sstream>
......@@ -47,8 +49,8 @@ struct ComputeScaleAndScaleInvFunctor {
n -= chunk_idx * chunk_size;
for (int i_start = threadIdx.x; i_start < n && i_start < chunk_size; i_start += blockDim.x) {
float scale_val = transformer_engine::compute_scale_from_amax(amax[i_start], max_fp8,
force_pow_2_scales, epsilon);
float scale_val = transformer_engine::compute_scale_from_amax(
amax[i_start], max_fp8, force_pow_2_scales, epsilon, std::numeric_limits<float>::max());
scale[i_start] = scale_val;
transformer_engine::reciprocal(scale_inv + i_start, scale_val);
}
......
......@@ -150,6 +150,7 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
// Quantize output if using unfused kernel
if (force_unfused_kernel) {
QuantizationConfigWrapper quant_config;
if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
......@@ -166,15 +167,18 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
}
QuantizationConfigWrapper quant_config;
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream());
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape);
} else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
auto my_quantizer_bw = static_cast<Float8BlockQuantizer *>(my_quantizer.get());
quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
}
nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr,
at::cuda::getCurrentCUDAStream());
nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,
at::cuda::getCurrentCUDAStream());
}
return {out, py::cast(mu), py::cast(rsigma)};
......@@ -293,6 +297,7 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
// Quantize output if using unfused kernel
if (force_unfused_kernel) {
QuantizationConfigWrapper quant_config;
if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
......@@ -309,15 +314,18 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
}
QuantizationConfigWrapper quant_config;
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream());
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape);
} else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
auto my_quantizer_bw = static_cast<Float8BlockQuantizer *>(my_quantizer.get());
quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
}
nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr,
at::cuda::getCurrentCUDAStream());
nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,
at::cuda::getCurrentCUDAStream());
}
return {out, py::none(), py::cast(rsigma)};
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "../extensions.h"
#ifdef NVTE_ENABLE_NVSHMEM
#include <nvshmem.h>
#include <nvshmem_api/nvshmem_waitkernel.h>
#include <nvshmemx.h>
#endif
#include <cuda.h>
#include <cuda_fp8.h>
#include <torch/cuda.h>
#include <torch/extension.h>
namespace nvshmem_api {
void init_nvshmem_backend(c10d::ProcessGroup *process_group) {
#ifdef NVTE_ENABLE_NVSHMEM
nvshmemx_init_attr_t attr = {};
nvshmemx_uniqueid_t id = {};
int my_rank = process_group->getRank();
int num_ranks = process_group->getSize();
if (my_rank == 0) {
nvshmemx_get_uniqueid(&id);
}
auto backend_is_nccl = (process_group->getBackendType() == c10d::ProcessGroup::BackendType::NCCL);
NVTE_CHECK(backend_is_nccl, "Currently only NCCL bootstrap is supported for NVSHMEM");
auto datatensor =
torch::from_blob(reinterpret_cast<void *>(&id),
{static_cast<int64_t>(sizeof(nvshmemx_uniqueid_t) / sizeof(uint8_t))},
at::device(torch::kCPU).dtype(torch::kUInt8));
auto datatmp = (backend_is_nccl) ? datatensor.cuda() : datatensor;
c10d::BroadcastOptions bcast_opts;
bcast_opts.rootRank = 0;
std::vector<torch::Tensor> datachunk = {datatmp};
auto work = process_group->broadcast(datachunk, bcast_opts);
work->wait();
if (backend_is_nccl) {
datatensor.copy_(datatmp.cpu());
datatmp = torch::Tensor();
}
nvshmemx_set_attr_uniqueid_args(my_rank, num_ranks, &id, &attr);
nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
NVTE_CHECK(my_rank == nvshmem_my_pe(), "my_rank: ", my_rank,
" != nvshmem_my_pe(): ", nvshmem_my_pe());
NVTE_CHECK(num_ranks == nvshmem_n_pes(), "num_ranks: ", num_ranks,
" != nvshmem_n_pes(): ", nvshmem_n_pes());
#else
NVTE_ERROR("Internal TE error: init_nvshmem_backend cannot be initialized with valid PyTorch ",
"distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
#endif
}
void nvshmem_wait_on_current_stream(torch::Tensor signal, const std::string &wait_kind) {
#ifdef NVTE_ENABLE_NVSHMEM
uint64_t *sig_addr = reinterpret_cast<uint64_t *>(signal.data_ptr());
cudaStream_t cur_stream = (cudaStream_t)at::cuda::getCurrentCUDAStream();
WaitKind wait_kind_enum = WaitKind::STREAM_WAIT;
if (wait_kind == "kernel") {
wait_kind_enum = WaitKind::KERNEL_WAIT;
} else if (wait_kind == "nvshmem") {
wait_kind_enum = WaitKind::NVSHMEM_WAIT;
} else if (wait_kind == "stream") {
wait_kind_enum = WaitKind::STREAM_WAIT;
} else {
NVTE_ERROR("Invalid wait kind: ", wait_kind);
}
nvshmem_wait_on_stream(sig_addr, wait_kind_enum, cur_stream);
#else
NVTE_ERROR(
"Internal TE error: nvshmem_wait_on_current_stream cannot be used ",
"unless TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
#endif
}
torch::Tensor create_nvshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype) {
#ifdef NVTE_ENABLE_NVSHMEM
auto option_gpu =
at::TensorOptions().dtype(dtype).device(at::kCUDA).device_index(c10::cuda::current_device());
auto size = torch::elementSize(dtype) *
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>());
return at::from_blob(
nvshmem_malloc(size), shape, [](void *ptr) { nvshmem_free(ptr); }, option_gpu);
#else
NVTE_ERROR("Internal TE error: create_nvshmem_tensor cannot be initialized with valid PyTorch ",
"distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
#endif
}
void nvshmem_send_on_current_stream(torch::Tensor src, torch::Tensor dst, int peer,
torch::Tensor signal) {
#ifdef NVTE_ENABLE_NVSHMEM
void *src_ptr = reinterpret_cast<void *>(src.data_ptr());
void *dst_ptr = reinterpret_cast<void *>(dst.data_ptr());
uint64_t *sig_addr = reinterpret_cast<uint64_t *>(signal.data_ptr());
auto nelement = src.numel() * src.element_size();
uint64_t sigval = 1;
at::cuda::CUDAStream cur_stream = at::cuda::getCurrentCUDAStream();
nvshmemx_putmem_signal_on_stream(dst_ptr, src_ptr, nelement, sig_addr, sigval, NVSHMEM_SIGNAL_SET,
peer, (cudaStream_t)cur_stream);
#else
NVTE_ERROR(
"Internal TE error: nvshmem_send_on_current_stream cannot be used ",
"unless TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
#endif
}
void nvshmem_finalize() {
#ifdef NVTE_ENABLE_NVSHMEM
nvshmem_finalize();
#else
NVTE_ERROR("Internal TE error: nvshmem_finalize cannot be initialized with valid PyTorch ",
"distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
#endif
}
} // namespace nvshmem_api
......@@ -17,7 +17,7 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
NVTE_CHECK(input.dim() == 2, "Dimension of input must equal 2.");
NVTE_CHECK(output.dim() == 2, "Dimension of output must equal 2.");
const int num_tensors = input_row_list.size();
const auto num_tensors = input_row_list.size();
// Extract properties from PyTorch tensors
std::vector<void*> input_dptr_list, output_dptr_list;
std::vector<std::vector<size_t>> input_shape_list, output_shape_list;
......
......@@ -52,18 +52,11 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
sorted_indices_ptr, row_id_ptr, sorted_row_id_ptr,
num_tokens * topK);
// Activations type
at::ScalarType _st;
if (dtype == transformer_engine::DType::kFloat8E4M3 ||
dtype == transformer_engine::DType::kFloat8E5M2)
_st = at::ScalarType::Byte;
else
_st = input.scalar_type();
// Output buffer alloc
num_out_tokens = (num_out_tokens > 0) ? num_out_tokens : num_tokens * topK;
at::Tensor permuted_output = torch::empty(
{num_out_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false));
at::Tensor permuted_output =
torch::empty({num_out_tokens, num_cols},
torch::dtype(input.scalar_type()).device(torch::kCUDA).requires_grad(false));
at::Tensor row_id_map = torch::empty(
{num_tokens * topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));
......@@ -100,17 +93,10 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType d
using namespace transformer_engine::pytorch;
int num_cols = input.size(1);
// Activations type
at::ScalarType _st;
if (dtype == transformer_engine::DType::kFloat8E4M3 ||
dtype == transformer_engine::DType::kFloat8E5M2)
_st = at::ScalarType::Byte;
else
_st = input.scalar_type();
// Output buffer alloc
at::Tensor unpermuted_output = torch::empty(
{num_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false));
at::Tensor unpermuted_output =
torch::empty({num_tokens, num_cols},
torch::dtype(input.scalar_type()).device(torch::kCUDA).requires_grad(false));
auto stream = at::cuda::getCurrentCUDAStream().stream();
......@@ -136,17 +122,10 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T
const int num_tokens = (prob.numel() > 0) ? prob.size(0) : row_id_map.size(0);
int num_cols = input_bwd.size(1);
// Activations type
at::ScalarType _st;
if (dtype == transformer_engine::DType::kFloat8E4M3 ||
dtype == transformer_engine::DType::kFloat8E5M2)
_st = at::ScalarType::Byte;
else
_st = input_bwd.scalar_type();
// Output buffer alloc
at::Tensor act_grad = torch::empty({input_fwd.size(0), num_cols},
torch::dtype(_st).device(torch::kCUDA).requires_grad(false));
at::Tensor act_grad =
torch::empty({input_fwd.size(0), num_cols},
torch::dtype(input_bwd.scalar_type()).device(torch::kCUDA).requires_grad(false));
at::Tensor prob_grad = torch::empty(
{num_tokens, topK}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false));
......
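The three MoE hunks above make the same change: the permuted/unpermuted output buffers no longer re-label FP8 activations as raw bytes (at::ScalarType::Byte) and instead reuse the input tensor's own scalar type. A purely illustrative contrast of the two allocation choices (not the extension API; assumes a PyTorch build with float8 dtypes):

import torch

inp = torch.randn(8, 16, device="cuda").to(torch.float8_e4m3fn)

# Old behaviour: FP8 inputs were handed back as uint8 ("Byte") output buffers.
old_out = torch.empty(inp.shape, dtype=torch.uint8, device="cuda")

# New behaviour: the output buffer simply inherits the input's scalar type.
new_out = torch.empty(inp.shape, dtype=inp.dtype, device="cuda")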
......@@ -28,6 +28,9 @@ PyTypeObject *Float8CurrentScalingQuantizerClass = nullptr;
PyTypeObject *MXFP8TensorPythonClass = nullptr; /// TODO Remove
PyTypeObject *MXFP8TensorBasePythonClass = nullptr;
PyTypeObject *MXFP8QuantizerClass = nullptr;
PyTypeObject *Float8BlockwiseQTensorPythonClass = nullptr;
PyTypeObject *Float8BlockwiseQTensorBasePythonClass = nullptr;
PyTypeObject *Float8BlockwiseQuantizerClass = nullptr;
void init_float8_extension() {
if (Float8TensorPythonClass) return;
......@@ -61,9 +64,31 @@ void init_mxfp8_extension() {
"Internal error: could not initialize pyTorch MXFP8 extension.");
}
void init_float8blockwise_extension() {
if (Float8BlockwiseQTensorBasePythonClass) return;
auto fp8_module =
py::module_::import("transformer_engine.pytorch.tensor.float8_blockwise_tensor");
auto fp8_base_module = py::module_::import(
"transformer_engine.pytorch.tensor._internal.float8_blockwise_tensor_base");
Float8BlockwiseQuantizerClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockQuantizer"));
Float8BlockwiseQTensorBasePythonClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(fp8_base_module.ptr(), "Float8BlockwiseQTensorBase"));
Float8BlockwiseQTensorPythonClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockwiseQTensor"));
NVTE_CHECK(Float8BlockwiseQuantizerClass != nullptr,
"Internal error: could not initialize pyTorch float8blockwise extension.");
NVTE_CHECK(Float8BlockwiseQTensorBasePythonClass != nullptr,
"Internal error: could not initialize pyTorch float8blockwise extension.");
NVTE_CHECK(Float8BlockwiseQTensorPythonClass != nullptr,
"Internal error: could not initialize pyTorch float8blockwise extension.");
}
void init_extension() {
init_float8_extension();
init_mxfp8_extension();
init_float8blockwise_extension();
}
} // namespace transformer_engine::pytorch
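init_float8blockwise_extension() resolves the new quantizer and tensor classes by importing them from the Python package at runtime. The equivalent Python-side imports, with module and class names taken verbatim from the C++ above (availability depends on the installed TE build), look like this:

from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
    Float8BlockQuantizer,
    Float8BlockwiseQTensor,
)
from transformer_engine.pytorch.tensor._internal.float8_blockwise_tensor_base import (
    Float8BlockwiseQTensorBase,
)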
......@@ -76,6 +101,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("output") = py::none(), py::arg("noop") = py::none());
m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"),
py::arg("otype"));
m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize,
"Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer"));
m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)",
......@@ -170,15 +196,17 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"), py::arg("sm_margin"),
py::arg("zero_centered_gamma"));
m.def("rmsnorm_bwd", &rmsnorm_bwd, "Backward of RMSNorm");
m.def("fused_multi_quantize", &fused_multi_quantize, "Fused Multi-tensor Cast + Transpose",
py::arg("input_list"), py::arg("output_list"), py::arg("quantizer_list"), py::arg("otype"));
m.def("fused_multi_quantize", &transformer_engine::pytorch::fused_multi_quantize,
"Fused Multi-tensor Cast + Transpose", py::arg("input_list"), py::arg("output_list"),
py::arg("quantizer_list"), py::arg("otype"));
m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM");
#ifdef USE_ROCM
m.def("te_batchgemm_ts", &te_batchgemm_ts, "Batched GEMM"); /// rocblas
#endif
m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O", py::arg("input"),
py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard<py::gil_scoped_release>());
m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O",
py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"),
py::call_guard<py::gil_scoped_release>());
m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend",
py::call_guard<py::gil_scoped_release>());
m.def("compute_amax", &compute_amax, "Compute amax", py::arg("input"), py::arg("amax"));
......@@ -206,10 +234,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::call_guard<py::gil_scoped_release>());
m.def("fused_rope_backward", &fused_rope_backward, "Fused Apply RoPE BWD",
py::call_guard<py::gil_scoped_release>());
m.def("fused_rope_thd_forward", &fused_rope_thd_forward, "Fused Apply RoPE FWD for thd format",
py::call_guard<py::gil_scoped_release>());
m.def("fused_rope_thd_backward", &fused_rope_thd_backward, "Fused Apply RoPE BWD for thd format",
py::call_guard<py::gil_scoped_release>());
// Misc
m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version",
......@@ -240,6 +264,23 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Generate partitioned indices for inputs in THD format",
py::call_guard<py::gil_scoped_release>());
// nvshmem functions
m.def("init_nvshmem_backend", &nvshmem_api::init_nvshmem_backend,
"Initialize nvshmem backend with Pytorch distributed process groups",
py::call_guard<py::gil_scoped_release>());
m.def("create_nvshmem_tensor", &nvshmem_api::create_nvshmem_tensor,
"Create a tensor in NVSHMEM shared memory", py::call_guard<py::gil_scoped_release>());
m.def("nvshmem_send_on_current_stream", &nvshmem_api::nvshmem_send_on_current_stream,
"Asynchronously send tensor data to a remote PE using NVSHMEM on the current CUDA stream",
py::call_guard<py::gil_scoped_release>());
m.def("nvshmem_wait_on_current_stream", &nvshmem_api::nvshmem_wait_on_current_stream,
"Wait for a signal value to be updated by a remote PE using NVSHMEM on the current CUDA "
"stream",
py::call_guard<py::gil_scoped_release>());
m.def("nvshmem_finalize", &nvshmem_api::nvshmem_finalize,
"Clean up and finalize the NVSHMEM communication backend and free associated resources",
py::call_guard<py::gil_scoped_release>());
// multi-tensor functions
m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
"Fused overflow check + scale for a list of contiguous tensors",
......