Commit f8c2af4c authored by yuguo

Merge commit '1d903f5e' of https://github.com/NVIDIA/TransformerEngine
parents e92773a3 1d903f5e
......@@ -196,10 +196,10 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
if (cudnn_runtime_version >= 90300) { \
num_segments = input_batch * max_segments_per_seq; \
} else { \
size_t runtime_num_segments_q = \
GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream); \
size_t runtime_num_segments_kv = \
GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream); \
size_t runtime_num_segments_q = nvte_get_runtime_num_segments( \
q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream); \
size_t runtime_num_segments_kv = nvte_get_runtime_num_segments( \
kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream); \
NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv); \
NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq); \
num_segments = runtime_num_segments_q; \
......@@ -248,7 +248,7 @@ static void FusedAttnForwardImpl(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim, head_dim, window_size_left, window_size_right);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
/* Auxiliary tensors (to be propagated to the backward pass later) */
NVTETensorPack aux_output_tensors;
......
......@@ -108,7 +108,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list,
auto rhs_sinv_shape = std::vector<size_t>{1, 1};
if (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) {
float *amax_dptr = nullptr;
float *scale_dptr = nullptr;
auto lhs_i_ = TensorWrapper(lhs_ptr, lhs_shape, lhs_dtype, amax_dptr, scale_dptr,
......
......@@ -44,6 +44,7 @@ enum class JAXX_Scaling_Mode : int64_t {
NO_SCALING = 0,
DELAYED_TENSOR_SCALING = 1,
MXFP8_1D_SCALING = 2,
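// Per-tensor scaling where the scale is computed from the tensor's current amax at quantization time.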
CURRENT_TENSOR_SCALING = 3,
};
static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) {
......@@ -57,6 +58,9 @@ static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) {
case JAXX_Scaling_Mode::MXFP8_1D_SCALING:
return NVTEScalingMode::NVTE_MXFP8_1D_SCALING;
break;
case JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING:
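// Current scaling reuses the delayed-scaling mode in the common library; the per-call
// scale is computed on the JAX side (see CurrentScaleQuantizer in this change).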
return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING;
break;
default:
NVTE_ERROR("Invalid Scaling Mode ", static_cast<int>(mode));
break;
......
......@@ -24,7 +24,7 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si
// empty tensor wrappers are okay just to get workspace size
auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
auto gamma_tensor = TensorWrapper(nullptr, weight_shape, in_dtype);
auto gamma_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32);
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
......@@ -98,7 +98,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
auto workspace_shape = std::vector<size_t>{workspace_size};
auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto gamma_tensor = TensorWrapper(gamma, gamma_shape, in_dtype);
auto gamma_tensor = TensorWrapper(gamma, gamma_shape, w_dtype);
auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, DType::kFloat32);
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - _sm_margin;
......@@ -107,6 +107,11 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), input_shape);
NVTE_CHECK(
scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
"Current tensor scaling does not support fused operations yet. Please call this primitive "
"in higher precision, then quantize with current scaling.");
if (is_fp8_dtype(out_dtype)) {
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
......@@ -118,7 +123,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) {
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
cudaMemsetAsync(amax, 0, sizeof(float), stream);
nvte_memset(amax, 0, sizeof(float), stream);
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
}
......@@ -134,6 +139,8 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
}
if (_norm_type == NVTE_Norm_Type::LayerNorm) {
NVTE_CHECK(w_dtype == convert_ffi_datatype_to_te_dtype(beta_buf.element_type()),
"gamma and beta must have the same data type.");
auto beta_tensor = TensorWrapper(beta, gamma_shape, w_dtype);
auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32);
......
......@@ -142,6 +142,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("NO_SCALING", JAXX_Scaling_Mode::NO_SCALING)
.value("DELAYED_TENSOR_SCALING", JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
.value("MXFP8_1D_SCALING", JAXX_Scaling_Mode::MXFP8_1D_SCALING)
.value("CURRENT_TENSOR_SCALING", JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING)
.export_values();
pybind11::enum_<transformer_engine::jax::QuantizeLayout>(m, "QuantizeLayout",
......
......@@ -7,6 +7,7 @@
#include "extensions.h"
#include "transformer_engine/cast.h"
#include "transformer_engine/recipe.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine {
......@@ -107,18 +108,21 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING;
if (quantize_layout == QuantizeLayout::ROWWISE ||
quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
output_tensor.set_rowwise_data(output, out_dtype, output_shape);
if (is_fp8_dtype(out_dtype)) {
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
if (is_tensor_scaling) {
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
NVTE_CHECK(scale != nullptr, "scale must be provided for tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for tensor scaling");
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
cudaMemsetAsync(amax, 0, sizeof(float), stream);
nvte_memset(amax, 0, sizeof(float), stream);
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
......@@ -142,11 +146,9 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
: output_shape;
output_tensor.set_columnwise_data(output_trans, out_dtype, tmp_shape);
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
? scale_inv_buf
: colwise_scale_inv_buf;
auto &tmp_buf = is_tensor_scaling ? scale_inv_buf : colwise_scale_inv_buf;
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
if (is_tensor_scaling) {
output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{1});
......@@ -159,6 +161,10 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
}
}
if (scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) {
output_tensor.set_amax(nullptr, DType::kFloat32, std::vector<size_t>{1});
}
auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);
......
......@@ -3,12 +3,26 @@
*
* See LICENSE for license information.
************************************************************************/
#include "utils.h"
#include "util.h"
#include <cuda_runtime_api.h>
#include "ATen/cuda/CUDAContextLight.h"
#include <cassert>
bool non_tn_fp8_gemm_supported() {
int major = at::cuda::getCurrentDeviceProperties()->major;
return major >= 10;
#include "common/util/cuda_runtime.h"
namespace transformer_engine {
namespace jax {
int GetCudaRuntimeVersion() {
int ver = 0;
NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&ver));
return ver;
}
size_t GetCudnnRuntimeVersion() { return cudnnGetVersion(); }
int GetDeviceComputeCapability(int gpu_id) { return transformer_engine::cuda::sm_arch(gpu_id); }
} // namespace jax
} // namespace transformer_engine
......@@ -4,9 +4,6 @@
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_
#define TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_
#include <pybind11/pybind11.h>
#include <transformer_engine/fused_attn.h>
......@@ -25,12 +22,6 @@ int GetCudaRuntimeVersion();
size_t GetCudnnRuntimeVersion();
int GetDeviceComputeCapability(int gpu_id);
void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
cudaStream_t stream);
uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream);
class cudaDevicePropertiesManager {
public:
static cudaDevicePropertiesManager &Instance() {
......@@ -63,28 +54,5 @@ class cudaDevicePropertiesManager {
cudaDeviceProp prop_;
};
class FusedAttnOffsetManager {
public:
static FusedAttnOffsetManager &Instance() {
static thread_local FusedAttnOffsetManager instance;
return instance;
}
size_t GetAndUpdateOffset(size_t increment) {
size_t ret = offset_;
offset_ += increment;
return ret;
}
FusedAttnOffsetManager(FusedAttnOffsetManager const &) = delete;
void operator=(FusedAttnOffsetManager const &) = delete;
private:
FusedAttnOffsetManager() {}
size_t offset_ = 0;
};
} // namespace jax
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_runtime_api.h>
#include <cassert>
#include "common/util/cuda_runtime.h"
#include "utils.h"
namespace transformer_engine {
namespace jax {
int GetCudaRuntimeVersion() {
int ver = 0;
NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&ver));
return ver;
}
size_t GetCudnnRuntimeVersion() { return cudnnGetVersion(); }
int GetDeviceComputeCapability(int gpu_id) { return transformer_engine::cuda::sm_arch(gpu_id); }
__global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t *const seed,
int64_t offset) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid > 0) return;
rng_state_dst[0] = seed[0];
rng_state_dst[1] = offset;
}
void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
cudaStream_t stream) {
size_t increment = 0;
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
increment = 16;
} else {
constexpr int threads_per_cta = 128;
increment = (q_max_seqlen * kv_max_seqlen + threads_per_cta - 1) / threads_per_cta;
}
auto offset = FusedAttnOffsetManager::Instance().GetAndUpdateOffset(increment);
populate_rng_state_kernel<<<1, 1, 0, stream>>>(reinterpret_cast<int64_t *>(rng_state_dst),
reinterpret_cast<const int64_t *>(seed), offset);
NVTE_CHECK_CUDA(cudaGetLastError());
}
__global__ void get_runtime_num_segments_kernel(int32_t *cu_seqlen, size_t len, uint32_t *out) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= len) return;
if (cu_seqlen[tid] > 0) {
// atomicAdd only supports 32-bit dtypes
atomicAdd(out, 1);
}
}
uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream) {
// the workspace only needs 4 bytes (a single uint32_t counter)
uint32_t *dout = static_cast<uint32_t *>(workspace);
uint32_t hout{};
cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream);
constexpr int threads = 128;
const int blocks = (len - 1) / threads + 1;
get_runtime_num_segments_kernel<<<blocks, threads, 0, stream>>>(static_cast<int32_t *>(cu_seqlen),
len, dout);
cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream);
cudaStreamSynchronize(stream);
return hout;
}
} // namespace jax
} // namespace transformer_engine
......@@ -49,6 +49,7 @@ def dense(
"""
# Remove when tex.quantize() can handle quantizer=None
if quantizer_set == noop_quantizer_set:
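# Apply the input sharding constraint on the unquantized (noop quantizer) path as well,
# so the GEMM below runs with the intended layout.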
x = with_sharding_constraint_by_logical_axes(x, input_axes)
output = tex.gemm(x, kernel, contracting_dims)
if bias is not None:
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
......@@ -183,6 +184,7 @@ def _dense_bwd_rule(
_dense.defvjp(_dense_fwd_rule, _dense_bwd_rule)
"""
def grouped_dense(
x_list,
kernel_list,
......@@ -190,10 +192,8 @@ def grouped_dense(
contracting_dims_list,
quantizer_set_list=None,
):
"""
Perform grouped_dense layer transformation with optional quantization.
# Perform grouped_dense layer transformation with optional quantization.
"""
output_list = _grouped_dense(
x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
)
......@@ -315,3 +315,4 @@ def _grouped_dense_bwd_rule(contracting_dims_list, ctx, grad_list):
_grouped_dense.defvjp(_grouped_dense_fwd_rule, _grouped_dense_bwd_rule)
"""
......@@ -11,7 +11,6 @@ from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union
import numpy as np
import jax.numpy as jnp
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from jax import lax
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name
......@@ -65,6 +64,7 @@ def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_ga
def _create_layernorm_parameters(
module,
norm_type,
shape,
scale_init,
......@@ -74,13 +74,21 @@ def _create_layernorm_parameters(
input_dtype,
dtype,
):
scale = nn_partitioning.param_with_axes("scale", scale_init, shape, dtype, axes=scale_axes)
scale = scale.astype(input_dtype)
scale = module.param(
"scale",
nn.with_logical_partitioning(scale_init, scale_axes),
shape,
dtype,
).astype(input_dtype)
norm_type = canonicalize_norm_type(norm_type)
if norm_type == "layernorm":
bias = nn_partitioning.param_with_axes("ln_bias", bias_init, shape, dtype, axes=bias_axes)
bias = jnp.asarray(bias, input_dtype)
bias = module.param(
"ln_bias",
nn.with_logical_partitioning(bias_init, bias_axes),
shape,
dtype,
).astype(input_dtype)
else:
assert norm_type == "rmsnorm"
bias = None
......@@ -308,6 +316,7 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
features = x.shape[-1]
scale, ln_bias = _create_layernorm_parameters(
self,
self.layernorm_type,
(features,),
self.scale_init,
......@@ -467,16 +476,22 @@ class DenseGeneral(TransformerEngineBase):
"Expected len(kernel_shape) to match len(kernel_axes),"
f"got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}"
)
kernel = nn_partitioning.param_with_axes(
"kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
kernel = self.param(
"kernel",
nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
kernel_shape,
self.dtype,
)
if not QuantizeConfig.is_fp8_enabled():
kernel = kernel.astype(input_dtype)
if self.use_bias:
bias = nn_partitioning.param_with_axes(
"bias", self.bias_init, features, self.dtype, axes=self.bias_axes
bias = self.param(
"bias",
nn.with_logical_partitioning(self.bias_init, self.bias_axes),
features,
self.dtype,
).astype(input_dtype)
else:
bias = None
......@@ -499,25 +514,21 @@ class DenseGeneral(TransformerEngineBase):
self.low_rank_adaptation_dim,
)
lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
lora_a_kernel = nn_partitioning.param_with_axes(
lora_a_kernel = self.param(
"lora_a_kernel",
self.kernel_init,
nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
lora_a_kernel_shape,
self.dtype,
axes=lora_a_kernel_axes,
)
lora_a_kernel = lora_a_kernel.astype(input_dtype)
).astype(input_dtype)
lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
lora_b_kernel = nn_partitioning.param_with_axes(
lora_b_kernel = self.param(
"lora_b_kernel",
nn.initializers.zeros,
nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
lora_b_kernel_shape,
self.dtype,
axes=lora_b_kernel_axes,
)
lora_b_kernel = lora_b_kernel.astype(input_dtype)
).astype(input_dtype)
y += _apply_low_rank_adaptation(
inputs, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
......@@ -695,6 +706,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
features = inputs.shape[-1]
scale, ln_bias = _create_layernorm_parameters(
self,
self.layernorm_type,
(features,),
self.scale_init,
......@@ -730,8 +742,11 @@ class LayerNormDenseGeneral(TransformerEngineBase):
axis = _normalize_axes(axis, y.ndim)
kernel_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
kernel = nn_partitioning.param_with_axes(
"kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
kernel = self.param(
"kernel",
nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
kernel_shape,
self.dtype,
)
if not QuantizeConfig.is_fp8_enabled():
kernel = kernel.astype(input_dtype)
......@@ -770,25 +785,21 @@ class LayerNormDenseGeneral(TransformerEngineBase):
self.low_rank_adaptation_dim,
)
lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
lora_a_kernel = nn_partitioning.param_with_axes(
lora_a_kernel = self.param(
"lora_a_kernel",
self.kernel_init,
nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
lora_a_kernel_shape,
self.dtype,
axes=lora_a_kernel_axes,
)
lora_a_kernel = lora_a_kernel.astype(input_dtype)
).astype(input_dtype)
lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
lora_b_kernel = nn_partitioning.param_with_axes(
lora_b_kernel = self.param(
"lora_b_kernel",
nn.initializers.zeros,
nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
lora_b_kernel_shape,
self.dtype,
axes=lora_b_kernel_axes,
)
lora_b_kernel = lora_b_kernel.astype(input_dtype)
).astype(input_dtype)
z += _apply_low_rank_adaptation(
y, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
......@@ -796,8 +807,11 @@ class LayerNormDenseGeneral(TransformerEngineBase):
bias = None
if self.use_bias:
bias = nn_partitioning.param_with_axes(
"bias", self.bias_init, features, self.dtype, axes=self.bias_axes
bias = self.param(
"bias",
nn.with_logical_partitioning(self.bias_init, self.bias_axes),
features,
self.dtype,
).astype(input_dtype)
if bias is not None:
......@@ -1028,6 +1042,7 @@ class LayerNormMLP(TransformerEngineBase):
features = inputs.shape[-1]
scale, ln_bias = _create_layernorm_parameters(
self,
self.layernorm_type,
(features,),
self.scale_init,
......@@ -1067,14 +1082,13 @@ class LayerNormMLP(TransformerEngineBase):
axis = _canonicalize_tuple(self.axis)
axis = _normalize_axes(axis, y.ndim)
kernel_1_each_shape = (np.prod([inputs.shape[ax] for ax in axis]), self.intermediate_dim)
kernel_1 = nn_partitioning.param_with_axes(
kernel_1 = self.param(
"wi_kernel",
kernel_1_init,
nn.with_logical_partitioning(kernel_1_init, self.kernel_axes_1),
num_activations,
-2,
kernel_1_each_shape,
self.dtype,
axes=self.kernel_axes_1,
)
if not QuantizeConfig.is_fp8_enabled():
......@@ -1083,12 +1097,11 @@ class LayerNormMLP(TransformerEngineBase):
hidden_size = inputs.shape[-1]
hidden_size_tuple = _canonicalize_tuple(hidden_size)
kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
kernel_2 = nn_partitioning.param_with_axes(
kernel_2 = self.param(
"wo_kernel",
self.kernel_init,
nn.with_logical_partitioning(self.kernel_init, self.kernel_axes_2),
kernel_2_shape,
self.dtype,
axes=self.kernel_axes_2,
)
if not QuantizeConfig.is_fp8_enabled():
kernel_2 = kernel_2.astype(input_dtype)
......@@ -1097,21 +1110,19 @@ class LayerNormMLP(TransformerEngineBase):
if self.use_bias:
bias_1_shape = (num_activations, self.intermediate_dim)
bias_1 = nn_partitioning.param_with_axes(
bias_1 = self.param(
"wi_bias",
self.bias_init,
nn.with_logical_partitioning(self.bias_init, self.bias_axes_1),
bias_1_shape,
self.dtype,
axes=self.bias_axes_1,
).astype(input_dtype)
bias_2_shape = (hidden_size,)
bias_2 = nn_partitioning.param_with_axes(
bias_2 = self.param(
"wo_bias",
self.bias_init,
nn.with_logical_partitioning(self.bias_init, self.bias_axes_2),
bias_2_shape,
self.dtype,
axes=self.bias_axes_2,
).astype(input_dtype)
else:
bias_1 = None
......@@ -1168,9 +1179,13 @@ class LayerNormMLP(TransformerEngineBase):
kernel_axes=self.kernel_axes_1,
quantizer_set=ffn1_quantizer_set,
)
if self.dot_1_input_axes is not None and self.kernel_axes_1 is not None:
dot_1_output_axes = (
*get_non_contracting_logical_axes(y.ndim, self.dot_1_input_axes, axis),
*get_non_contracting_logical_axes(kernel_1.ndim, self.kernel_axes_1, contract_ind),
*get_non_contracting_logical_axes(
kernel_1.ndim, self.kernel_axes_1, contract_ind
),
)
x = with_sharding_constraint_by_logical_axes(x, dot_1_output_axes)
......@@ -1180,16 +1195,14 @@ class LayerNormMLP(TransformerEngineBase):
self.low_rank_adaptation_dim,
)
wi_lora_a_kernel_axes = (None,) * (len(wi_lora_a_kernel_each_shape) + 1)
wi_lora_a_kernel = nn_partitioning.param_with_axes(
wi_lora_a_kernel = self.param(
"wi_lora_a_kernel",
kernel_1_init,
nn.with_logical_partitioning(kernel_1_init, wi_lora_a_kernel_axes),
num_activations,
-2,
wi_lora_a_kernel_each_shape,
self.dtype,
axes=wi_lora_a_kernel_axes,
)
wi_lora_a_kernel = wi_lora_a_kernel.astype(input_dtype)
).astype(input_dtype)
wi_lora_b_kernel_shape = (
num_activations,
......@@ -1197,14 +1210,12 @@ class LayerNormMLP(TransformerEngineBase):
self.intermediate_dim,
)
wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape)
wi_lora_b_kernel = nn_partitioning.param_with_axes(
wi_lora_b_kernel = self.param(
"wi_lora_b_kernel",
nn.initializers.zeros,
nn.with_logical_partitioning(nn.initializers.zeros, wi_lora_b_kernel_axes),
wi_lora_b_kernel_shape,
self.dtype,
axes=wi_lora_b_kernel_axes,
)
wi_lora_b_kernel = wi_lora_b_kernel.astype(input_dtype)
).astype(input_dtype)
x += _apply_low_rank_adaptation(
y,
......@@ -1253,25 +1264,21 @@ class LayerNormMLP(TransformerEngineBase):
if self.enable_low_rank_adaptation:
wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim)
wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape)
wo_lora_a_kernel = nn_partitioning.param_with_axes(
wo_lora_a_kernel = self.param(
"wo_lora_a_kernel",
self.kernel_init,
nn.with_logical_partitioning(self.kernel_init, wo_lora_a_kernel_axes),
wo_lora_a_kernel_shape,
self.dtype,
axes=wo_lora_a_kernel_axes,
)
wo_lora_a_kernel = wo_lora_a_kernel.astype(input_dtype)
).astype(input_dtype)
wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
wo_lora_b_kernel = nn_partitioning.param_with_axes(
wo_lora_b_kernel = self.param(
"wo_lora_b_kernel",
nn.initializers.zeros,
nn.with_logical_partitioning(nn.initializers.zeros, wo_lora_b_kernel_axes),
wo_lora_b_kernel_shape,
self.dtype,
axes=wo_lora_b_kernel_axes,
)
wo_lora_b_kernel = wo_lora_b_kernel.astype(input_dtype)
).astype(input_dtype)
out += _apply_low_rank_adaptation(
z,
......
......@@ -15,7 +15,6 @@ import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import combine_masks
from jax import nn as jax_nn
from jax import random as jax_random
......@@ -1503,12 +1502,11 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-metho
rp_bucket += np.where(rpb_is_small, negative_rp, rpb_val_if_large)
# Compute relative attention bias
relative_attention_bias = nn_partitioning.param_with_axes(
relative_attention_bias = self.param(
"rel_embedding",
self.embedding_init,
nn.with_logical_partitioning(self.embedding_init, self.embedding_axes),
(self.num_attention_heads, self.num_buckets),
self.dtype,
axes=self.embedding_axes,
)
relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)
......
......@@ -275,6 +275,7 @@ def _layernorm_mlp_fwd_rule(
(x_contracting_dims, k_contracting_dims),
)
if dot_1_input_axes is not None and kernel_1_axes is not None:
dot_1_output_axes = (
*get_non_contracting_logical_axes(x.ndim, dot_1_input_axes, x_contracting_dims),
*get_non_contracting_logical_axes(kernel_1.ndim, kernel_1_axes, k_contracting_dims),
......@@ -303,12 +304,6 @@ def _layernorm_mlp_fwd_rule(
(x_contracting_dims, k_contracting_dims),
)
dot_2_output_axes = (
*get_non_contracting_logical_axes(x.ndim, dot_2_input_axes, x_contracting_dims),
*get_non_contracting_logical_axes(kernel_2.ndim, None, k_contracting_dims),
)
dot_2_output = with_sharding_constraint_by_logical_axes(dot_2_output, dot_2_output_axes)
if use_bias_2:
bias_2_shape = bias_2.shape
bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
......
......@@ -85,6 +85,7 @@ class Dequantizer:
funcs = {
ScalingMode.DELAYED_TENSOR_SCALING: _dq_func_tensor_scaling,
ScalingMode.CURRENT_TENSOR_SCALING: _dq_func_tensor_scaling,
ScalingMode.MXFP8_1D_SCALING: _dq_func_block_scaling,
}
......
......@@ -94,7 +94,7 @@ def _check_fp8_support(scaling_mode, gpu_id) -> Tuple[bool, str]:
A tuple of (bool, str) indicating support and any error message
"""
gpu_arch = get_device_compute_capability(gpu_id)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
if scaling_mode.is_tensor_scaling():
return _check_delayed_scaling_fp8_support(gpu_arch)
if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
return _check_block_scaling_fp8_support(gpu_arch)
......@@ -182,6 +182,8 @@ def _get_scaling_mode(fp8_recipe: recipe.Recipe) -> ScalingMode:
return ScalingMode.DELAYED_TENSOR_SCALING
if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
return ScalingMode.MXFP8_1D_SCALING
if isinstance(fp8_recipe, recipe.Float8CurrentScaling):
return ScalingMode.CURRENT_TENSOR_SCALING
raise ValueError("Invalid fp8_recipe!")
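
Illustrative mapping produced by the branches above (recipe.DelayedScaling is assumed to hit the first, unchanged branch):

assert _get_scaling_mode(recipe.DelayedScaling()) == ScalingMode.DELAYED_TENSOR_SCALING
assert _get_scaling_mode(recipe.MXFP8BlockScaling()) == ScalingMode.MXFP8_1D_SCALING
assert _get_scaling_mode(recipe.Float8CurrentScaling()) == ScalingMode.CURRENT_TENSOR_SCALING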
......@@ -240,7 +242,7 @@ class QuantizeConfig:
fp8_recipe: The FP8 recipe to use for initialization
"""
cls.INITIALIZED = True
cls.MARGIN = fp8_recipe.margin
cls.MARGIN = fp8_recipe.margin if "margin" in dir(fp8_recipe) else 0.0
cls.FP8_FORMAT = fp8_recipe.fp8_format
cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT)
cls.SCALING_MODE = _get_scaling_mode(fp8_recipe)
......@@ -309,6 +311,30 @@ class DelayedScalingQuantizeConfig:
QuantizeConfig.finalize()
class CurrentScalingQuantizeConfig:
"""Configuration class for current scaling FP8 recipe.
This class provides specific initialization and finalization for current scaling
FP8 quantization mode.
"""
@staticmethod
def initialize(fp8_recipe: recipe.Recipe) -> None:
"""Initialize current scaling FP8 configuration.
Args:
fp8_recipe: The FP8 recipe to use for initialization
"""
cls = QuantizeConfig
cls.initialize(fp8_recipe)
cls.AMAX_HISTORY_LEN = 0
@staticmethod
def finalize() -> None:
"""Reset the current scaling configuration."""
QuantizeConfig.finalize()
class BlockScalingQuantizeConfig:
"""Configuration class for block scaling FP8 recipe.
......@@ -385,6 +411,8 @@ def fp8_autocast(
Config = DelayedScalingQuantizeConfig
if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
Config = BlockScalingQuantizeConfig
if isinstance(fp8_recipe, recipe.Float8CurrentScaling):
Config = CurrentScalingQuantizeConfig
try:
with global_shard_guard(mesh_resource):
......
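
A hedged usage sketch of the new recipe wiring (import paths are assumed; the model code inside the context is hypothetical, while fp8_autocast and recipe.Float8CurrentScaling are the objects referenced in this hunk):

from transformer_engine.common import recipe
import transformer_engine.jax as te

with te.fp8_autocast(enabled=True, fp8_recipe=recipe.Float8CurrentScaling()):
    # Layers built here pick up CurrentScalingQuantizeConfig, i.e.
    # ScalingMode.CURRENT_TENSOR_SCALING with AMAX_HISTORY_LEN = 0.
    ...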
......@@ -27,13 +27,36 @@ __all__ = [
"QuantizeLayout",
"Quantizer",
"QuantizerSet",
"CurrentScaleQuantizer",
"DelayedScaleQuantizer",
"BlockScaleQuantizer",
"QuantizerFactory",
"noop_quantizer_set",
"compute_scale_from_amax",
]
def compute_scale_from_amax(
amax: jnp.ndarray, q_dtype: jnp.dtype, scale: Optional[jnp.ndarray] = None
) -> jnp.ndarray:
"""Compute scale from amax value.
Args:
amax: Maximum absolute value of the tensor
q_dtype: Quantization data type
scale: Optional fallback scale, returned when amax is zero or non-finite (defaults to ones)
Returns:
Scale value
"""
fp8_max = jnp.astype(jnp.finfo(q_dtype).max, jnp.float32)
if scale is None:
scale = jnp.ones((1,))
sf = (fp8_max / amax) / (2**QuantizeConfig.MARGIN)
sf = jnp.where(amax > 0.0, sf, scale)
sf = jnp.where(jnp.isfinite(amax), sf, scale)
return sf
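
For illustration, a minimal use of the helper above (the amax value is hypothetical, and MARGIN is assumed to be 0, its default):

import jax.numpy as jnp

amax = jnp.array([3.0], dtype=jnp.float32)           # hypothetical tensor amax
scale = compute_scale_from_amax(amax, jnp.float8_e4m3fn)
# fp8 e4m3 max is 448, so scale is roughly 448 / 3 = 149.3; a zero or non-finite
# amax would instead return the fallback scale (ones by default).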
@register_pytree_node_class
@dataclass
class Quantizer(ABC):
......@@ -159,37 +182,19 @@ class Quantizer(ABC):
@register_pytree_node_class
@dataclass
class DelayedScaleQuantizer(Quantizer):
"""Quantizer implementation using delayed scaling.
class CurrentScaleQuantizer(Quantizer):
"""Quantizer implementation using current scaling.
This quantizer uses delayed scaling mode with float32 scales and maintains
a history of maximum absolute values for dynamic scaling.
This quantizer uses current scaling mode with float32 scales.
Attributes:
scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING
q_layout: Quantization axis (default: ROWWISE_COLWISE)
scale: Current scaling factor
amax_history: History of maximum absolute values
"""
scaling_mode: ScalingMode = ScalingMode.DELAYED_TENSOR_SCALING
scaling_mode: ScalingMode = ScalingMode.CURRENT_TENSOR_SCALING
q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32))
amax_history: jnp.ndarray = field(
default_factory=lambda: jnp.zeros((QuantizeConfig.AMAX_HISTORY_LEN,), jnp.float32)
)
def tree_flatten(self):
"""Flatten the quantizer for JAX tree operations.
Returns:
Tuple of (children, aux_data) for tree operations
"""
children = (self.scale, self.amax_history)
aux_data = (self.q_dtype, self.scaling_mode, self.q_layout)
return (children, aux_data)
def get_data_layout(self) -> str:
"""Get the data layout string.
......@@ -217,15 +222,18 @@ class DelayedScaleQuantizer(Quantizer):
x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
flatten_axis: The quantization axis for the tensor
Returns:
A ScaledTensor1x containing the quantized data
"""
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
compute_dtype = self.scale.dtype
compute_dtype = jnp.float32
dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
scaled_x = x.astype(compute_dtype) * self.scale
amax = jnp.max(jnp.abs(x)).reshape((1,)).astype(compute_dtype)
fp8_max = jnp.astype(jnp.finfo(self.q_dtype).max, jnp.float32)
scale = (fp8_max / amax) / (2**QuantizeConfig.MARGIN)
scaled_x = x.astype(compute_dtype) * scale
# quantize() in the old dot.py did it this way; keep this code block here for future debugging
# compute_dtype = x.dtype
......@@ -233,8 +241,7 @@ class DelayedScaleQuantizer(Quantizer):
# scaled_x = x * self.scale.astype(compute_dtype)
clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
scale_inv = 1.0 / self.scale
self.update(jnp.max(jnp.abs(x)).reshape((1,)))
scale_inv = 1.0 / scale
return ScaledTensorFactory.create_1x(
data=clipped_scaled_x,
scale_inv=scale_inv,
......@@ -294,6 +301,75 @@ class DelayedScaleQuantizer(Quantizer):
return colwise_tensor
return rowwise_tensor
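
Conceptually, the current-scaling quantize path above reduces to the following self-contained sketch (simplified: it ignores layouts, MARGIN, and the ScaledTensor wrapper):

import jax.numpy as jnp

def quantize_current_scaling(x, q_dtype=jnp.float8_e4m3fn):
    # The scale is derived from the tensor's own amax at call time; no history is kept.
    amax = jnp.max(jnp.abs(x)).astype(jnp.float32)
    fp8_max = jnp.finfo(q_dtype).max.astype(jnp.float32)
    scale = jnp.where(amax > 0.0, fp8_max / amax, 1.0)
    data = jnp.clip(x.astype(jnp.float32) * scale, -fp8_max, fp8_max).astype(q_dtype)
    return data, 1.0 / scale   # quantized data and scale_inv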
@register_pytree_node_class
@dataclass
class DelayedScaleQuantizer(CurrentScaleQuantizer):
"""Quantizer implementation using delayed scaling.
This quantizer uses delayed scaling mode with float32 scales and maintains
a history of maximum absolute values for dynamic scaling.
Attributes:
scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING
q_layout: Quantization axis (default: ROWWISE_COLWISE)
scale: Current scaling factor
amax_history: History of maximum absolute values
"""
scaling_mode: ScalingMode = ScalingMode.DELAYED_TENSOR_SCALING
q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32))
amax_history: jnp.ndarray = field(
default_factory=lambda: jnp.zeros((QuantizeConfig.AMAX_HISTORY_LEN,), jnp.float32)
)
def tree_flatten(self):
"""Flatten the quantizer for JAX tree operations.
Returns:
Tuple of (children, aux_data) for tree operations
"""
children = (self.scale, self.amax_history)
aux_data = (self.q_dtype, self.scaling_mode, self.q_layout)
return (children, aux_data)
def _quantize_func(
self, x: jnp.ndarray, is_colwise=False, dq_dtype=None, flatten_axis=-1
) -> ScaledTensor1x:
"""Quantize function helper for delayed scaling FP8.
Args:
x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
flatten_axis: The quantization axis for the tensor
Returns:
A ScaledTensor1x containing the quantized data
"""
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
compute_dtype = jnp.float32
dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
scaled_x = x.astype(compute_dtype) * self.scale
# quantize() in the old dot.py did it this way; keep this code block here for future debugging
# compute_dtype = x.dtype
# dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
# scaled_x = x * self.scale.astype(compute_dtype)
clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
scale_inv = 1.0 / self.scale
self.update(jnp.max(jnp.abs(x)).reshape((1,)))
return ScaledTensorFactory.create_1x(
data=clipped_scaled_x,
scale_inv=scale_inv,
scaling_mode=self.scaling_mode,
dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
)
@staticmethod
@jax.jit
def _update_amax_history(amax_history, new_amax):
......@@ -323,18 +399,12 @@ class DelayedScaleQuantizer(Quantizer):
Updated scale value
"""
# 2. Calculate the current scale
fp8_max = jnp.astype(jnp.finfo(q_dtype).max, jnp.float32)
if QuantizeConfig.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
amax = jnp.max(amax_history, axis=-1, keepdims=True)
else:
amax = amax_history[0:1]
sf = (fp8_max / amax) / (2**QuantizeConfig.MARGIN)
sf = jnp.where(amax > 0.0, sf, scale)
sf = jnp.where(jnp.isfinite(amax), sf, scale)
scale = scale.at[0].set(sf[0])
return scale
return compute_scale_from_amax(amax, q_dtype, scale=scale)
@staticmethod
@jax.jit
......@@ -531,6 +601,7 @@ class QuantizerFactory:
quantizer_type_map = {
ScalingMode.DELAYED_TENSOR_SCALING: DelayedScaleQuantizer,
ScalingMode.CURRENT_TENSOR_SCALING: CurrentScaleQuantizer,
ScalingMode.MXFP8_1D_SCALING: BlockScaleQuantizer,
}
......
......@@ -95,10 +95,10 @@ class ScalingModeMetadataImpl(ABC):
"""
class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
"""Implementation for delayed scaling mode.
class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
"""Implementation for current scaling mode.
This implementation provides metadata for delayed scaling mode, including scale data type and shape.
This implementation provides metadata for current scaling mode, including scale data type and shape.
"""
def get_scale_dtype(self) -> jnp.dtype:
......@@ -148,6 +148,13 @@ class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
return QuantizeShardyRules(input_spec, (unique_var,), (unique_var,), {})
class DelayedScalingModeMetadataImpl(CurrentScalingModeMetadataImpl):
"""Implementation for delayed scaling mode.
This implementation provides metadata for delayed scaling mode, including scale data type and shape.
"""
class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
"""Implementation for block scaling mode.
......@@ -317,12 +324,14 @@ class ScalingMode(Enum):
This class defines the available scaling modes for tensor quantization:
- DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales
- MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales
- CURRENT_TENSOR_SCALING: Uses current scaling with FP8 data type and float32 scales
- NO_SCALING: No scaling applied
"""
NO_SCALING = JAXX_Scaling_Mode.NO_SCALING
DELAYED_TENSOR_SCALING = JAXX_Scaling_Mode.DELAYED_TENSOR_SCALING
MXFP8_1D_SCALING = JAXX_Scaling_Mode.MXFP8_1D_SCALING
CURRENT_TENSOR_SCALING = JAXX_Scaling_Mode.CURRENT_TENSOR_SCALING
def _get_impl(self) -> ScalingModeMetadataImpl:
"""Get the implementation for this scaling mode.
......@@ -395,6 +404,25 @@ class ScalingMode(Enum):
"""
return self._get_impl().get_shardy_sharding_rules(input_rank, unique_var, flatten_axis)
def is_tensor_scaling(self) -> bool:
"""Check if this scaling mode is per-tensor scaling.
Returns:
True if the scaling mode is tensor scaling, False otherwise
"""
return self in (
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
)
def is_1d_block_scaling(self) -> bool:
"""Check if this scaling mode is 1D block scaling.
Returns:
True if the scaling mode is 1D block scaling, False otherwise
"""
return self == ScalingMode.MXFP8_1D_SCALING
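
Illustrative checks of how the two new helpers partition the modes:

assert ScalingMode.DELAYED_TENSOR_SCALING.is_tensor_scaling()
assert ScalingMode.CURRENT_TENSOR_SCALING.is_tensor_scaling()
assert ScalingMode.MXFP8_1D_SCALING.is_1d_block_scaling()
assert not ScalingMode.MXFP8_1D_SCALING.is_tensor_scaling()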
def __eq__(self, other):
"""Compare this scaling mode with another.
......@@ -434,5 +462,6 @@ SCALING_MODES_TO_IMPL: Dict[ScalingMode, ScalingModeMetadataImpl] = {
ScalingMode.DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(),
ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)),
# WAR
ScalingMode.CURRENT_TENSOR_SCALING: CurrentScalingModeMetadataImpl(),
ScalingMode.NO_SCALING: DelayedScalingModeMetadataImpl(),
}
......@@ -84,6 +84,7 @@ if __name__ == "__main__":
The script requires JAX to be installed for building.
It will raise a RuntimeError if JAX is not available.
"""
# Extensions
common_headers_dir = "common_headers"
copy_common_headers(current_file_path.parent, str(current_file_path / common_headers_dir))
......@@ -100,6 +101,17 @@ if __name__ == "__main__":
description="Transformer acceleration library - Jax Lib",
ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension},
setup_requires=[
"jax[cuda12]",
"flax>=0.7.1",
"nvidia-cuda-runtime-cu12",
"nvidia-cublas-cu12",
"nvidia-cudnn-cu12",
"nvidia-cuda-cccl-cu12",
"nvidia-cuda-nvcc-cu12",
"nvidia-nvtx-cu12",
"nvidia-cuda-nvrtc-cu12",
],
install_requires=["jax", "flax>=0.7.1"],
tests_require=["numpy"],
)
......
......@@ -13,7 +13,7 @@ import os
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from typing import Callable
from typing import Callable, Optional
from jax.interpreters import pxla
import jax
import jax.numpy as jnp
......@@ -112,9 +112,22 @@ def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
return jax.lax.with_sharding_constraint(x, pspec)
def with_sharding_constraint_by_logical_axes(x: jnp.array, logical_axis_names: tuple | list):
def with_sharding_constraint_by_logical_axes(
x: jnp.array, logical_axis_names: Optional[tuple | list]
):
"""
A wrapper function to jax.lax.with_sharding_constraint to accept logical axes.
If logical_axis_names = None, this means no sharding constraint is applied.
If logical_axis_names = (None, None, ...), this means a sharding constraint is applied and the tensor is replicated across all devices.
Args:
x: Input tensor to apply sharding constraint
logical_axis_names: Logical axis names to apply sharding constraint
Returns:
Tensor with sharding constraint applied, or the original tensor if no logical axes are provided.
"""
if not logical_axis_names:
return x
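
Illustrative behaviour of the two cases described in the docstring (the logical axis name is hypothetical, and a surrounding mesh/logical-axis-rules context is assumed for the constrained call):

import jax.numpy as jnp

x = jnp.zeros((8, 128))
y = with_sharding_constraint_by_logical_axes(x, None)             # no constraint; x is returned unchanged
z = with_sharding_constraint_by_logical_axes(x, ("batch", None))  # constrains dim 0 to whatever mesh axis
                                                                   # "batch" maps to; dim 1 is unconstrained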
......@@ -321,7 +334,9 @@ class ShardingType(Enum):
DP_TP_ROW = (MajorShardingType.DPTP, "dp_tp_row")
def get_non_contracting_logical_axes(ndim, logical_axes, contracting_dims):
def get_non_contracting_logical_axes(
ndim, logical_axes: tuple[Optional[str]], contracting_dims
) -> tuple[Optional[str]]:
"""Get logical axes for non-contracting dimensions.
Args:
......@@ -332,11 +347,8 @@ def get_non_contracting_logical_axes(ndim, logical_axes, contracting_dims):
Returns:
Tuple of logical axes for non-contracting dimensions.
"""
if not logical_axes:
logical_axes = (None,) * ndim
elif len(logical_axes) < ndim:
logical_axes = logical_axes + (None,) * (ndim - len(logical_axes))
assert len(logical_axes) == ndim
assert logical_axes is not None, "Logical axes must be a tuple and cannot be None."
assert len(logical_axes) == ndim, "Logical axes must match the number of dimensions."
non_contracting_dims = [i for i in range(ndim) if i not in contracting_dims]
non_contracting_logical_axes = tuple(logical_axes[i] for i in non_contracting_dims)
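
A small worked example of the stricter contract (the axes tuple must now cover every dimension):

# A 3-D activation with logical axes (batch, seq, hidden), contracting over the last dim:
get_non_contracting_logical_axes(3, ("batch", "seq", "hidden"), (2,))
# -> ("batch", "seq")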
......
......@@ -4,22 +4,14 @@
"""Transformer Engine bindings for pyTorch"""
# pylint: disable=wrong-import-position,wrong-import-order
# pylint: disable=wrong-import-position
import logging
import functools
import sys
import importlib
import importlib.util
from importlib.metadata import version
from packaging.version import Version as PkgVersion
import torch
from transformer_engine.common import get_te_path, is_package_installed
from transformer_engine.common import _get_sys_extension
_logger = logging.getLogger(__name__)
from transformer_engine.common import load_framework_extension
@functools.lru_cache(maxsize=None)
......@@ -28,57 +20,10 @@ def torch_version() -> tuple[int, ...]:
return PkgVersion(str(torch.__version__)).release
def _load_library():
"""Load shared library with Transformer Engine C extensions"""
module_name = "transformer_engine_torch"
if is_package_installed(module_name):
assert is_package_installed("transformer_engine"), "Could not find `transformer-engine`."
assert is_package_installed(
"transformer_engine_cu12"
), "Could not find `transformer-engine-cu12`."
assert (
version(module_name)
== version("transformer-engine")
== version("transformer-engine-cu12")
), (
"TransformerEngine package version mismatch. Found"
f" {module_name} v{version(module_name)}, transformer-engine"
f" v{version('transformer-engine')}, and transformer-engine-cu12"
f" v{version('transformer-engine-cu12')}. Install transformer-engine using "
"'pip3 install transformer-engine[pytorch]==VERSION'"
)
if is_package_installed("transformer-engine-cu12"):
if not is_package_installed(module_name):
_logger.info(
"Could not find package %s. Install transformer-engine using "
"'pip3 install transformer-engine[pytorch]==VERSION'",
module_name,
)
extension = _get_sys_extension()
try:
so_dir = get_te_path() / "transformer_engine"
so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
except StopIteration:
try:
so_dir = get_te_path() / "transformer_engine" / "wheel_lib"
so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
except StopIteration:
so_dir = get_te_path()
so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
spec = importlib.util.spec_from_file_location(module_name, so_path)
solib = importlib.util.module_from_spec(spec)
sys.modules[module_name] = solib
spec.loader.exec_module(solib)
assert torch_version() >= (2, 1), f"Minimum torch version 2.1 required. Found {torch_version()}."
_load_library()
load_framework_extension("torch")
from transformer_engine.pytorch.module import LayerNormLinear
from transformer_engine.pytorch.module import Linear
from transformer_engine.pytorch.module import LayerNormMLP
......@@ -90,7 +35,8 @@ from transformer_engine.pytorch.module import initialize_ub
from transformer_engine.pytorch.module import destroy_ub
from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.attention import MultiheadAttention
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.attention import InferenceParams
from transformer_engine.pytorch.attention import RotaryPositionEmbedding
from transformer_engine.pytorch.transformer import TransformerLayer
from transformer_engine.pytorch.permutation import (
moe_permute,
......