Commit 2b05e121 authored by yuguo

Merge commit 'a69692ac' of https://github.com/NVIDIA/TransformerEngine
parents 0fd441c2 a69692ac
......@@ -77,7 +77,7 @@ void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
using namespace transformer_engine;
multi_tensor_compute_scale::multi_tensor_compute_scale_and_scale_inv_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), max_fp8,
force_pow_2_scales, epsilon, device_id, stream);
}
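
These hunks all apply the same change: raw reinterpret_cast<Tensor *> conversions of opaque NVTETensor handles are replaced with convertNVTETensorCheck, which validates the handle before dereferencing. A minimal standalone sketch of the idea, with hypothetical names (the real conversion is defined later in this commit against the TensorAllocator):

#include <cstdint>
#include <stdexcept>
#include <vector>

struct Tensor { int scaling_mode = 0; };
using NVTETensor = void*;  // opaque handle: a 1-based allocator index in disguise

static std::vector<Tensor> registry(16);

Tensor* checkedConvert(NVTETensor h) {
  const uintptr_t index = reinterpret_cast<uintptr_t>(h);
  // reinterpret_cast<Tensor*>(h) would treat the index as an address and
  // dereference arbitrary memory; the checked conversion rejects it instead.
  if (index == 0 || index > registry.size()) throw std::runtime_error("Invalid tensor.");
  return &registry[index - 1];  // 1-based, so a zero-initialized handle stays invalid
}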
......@@ -467,10 +467,10 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen
using namespace transformer_engine;
multi_tensor_l2norm::multi_tensor_l2norm_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
*reinterpret_cast<Tensor *>(output), *reinterpret_cast<Tensor *>(output_per_tensor),
*reinterpret_cast<Tensor *>(ret), *reinterpret_cast<Tensor *>(ret_per_tensor), per_tensor,
*convertNVTETensorCheck(output), *convertNVTETensorCheck(output_per_tensor),
*convertNVTETensorCheck(ret), *convertNVTETensorCheck(ret_per_tensor), per_tensor,
max_chunks_per_tensor, device_id, stream);
}
......@@ -485,9 +485,9 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
using namespace transformer_engine;
multi_tensor_l2norm::multi_tensor_unscale_l2norm_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
*reinterpret_cast<Tensor *>(output), *reinterpret_cast<Tensor *>(output_per_tensor),
*reinterpret_cast<Tensor *>(ret), *reinterpret_cast<Tensor *>(ret_per_tensor),
*reinterpret_cast<Tensor *>(inv_scale), per_tensor, max_chunks_per_tensor, device_id, stream);
*convertNVTETensorCheck(output), *convertNVTETensorCheck(output_per_tensor),
*convertNVTETensorCheck(ret), *convertNVTETensorCheck(ret_per_tensor),
*convertNVTETensorCheck(inv_scale), per_tensor, max_chunks_per_tensor, device_id, stream);
}
......@@ -124,7 +124,7 @@ void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETens
using namespace transformer_engine;
multi_tensor_scale::multi_tensor_scale_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), scale, device_id,
stream);
}
......@@ -196,7 +196,7 @@ void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor
using namespace transformer_engine;
multi_tensor_sgd::multi_tensor_sgd_cuda(
chunk_size, *reinterpret_cast<Tensor*>(noop_flag),
chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), wd, momentum,
dampening, lr, nesterov, first_run, wd_after_momentum, scale, device_id, stream);
}
......@@ -39,8 +39,6 @@ Compute always in FP32
namespace transformer_engine {
namespace normalization {
bool& use_zero_centered_gamma_in_weight_dtype();
#ifndef __HIP_PLATFORM_AMD__
cudnn_frontend::NormFwdPhase_t get_cudnn_forward_phase(const bool training) {
return training ? cudnn_frontend::NormFwdPhase_t::TRAINING
......@@ -51,13 +49,17 @@ cudnn_frontend::NormFwdPhase_t get_cudnn_forward_phase(const bool training) {
TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType,
NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype,
uint64_t batch_size, uint64_t hidden_size, bool zero_centered_gamma,
bool is_tuned, NVTEScalingMode mode, bool training) {
// TODO: Add scaling_mode to general_key if needed
uint64_t general_key = static_cast<uint32_t>(itype) | (static_cast<uint32_t>(otype) << 3) |
(static_cast<uint32_t>(ctype) << 6) | (static_cast<uint32_t>(wtype) << 9) |
(uint32_t(NormType) << 12) | (uint32_t(NormStage)) << 14 |
(uint32_t(NormBackend) << 16) | (uint32_t(zero_centered_gamma) << 18) |
(uint32_t(mode) << 19) | (uint32_t(training) << 22);
bool is_tuned, NVTEScalingMode mode, bool training,
bool gamma_in_weight_dtype) {
static_assert(NVTE_INVALID_SCALING < 1024,
"This function assumes at most 10 bits used in the scaling mode.");
static_assert(kNVTENumTypes < 32, "This function assumes at most 5 bits used in the NVTEDType");
uint64_t general_key = static_cast<uint64_t>(itype) | (static_cast<uint64_t>(otype) << 5) |
(static_cast<uint64_t>(ctype) << 10) |
(static_cast<uint64_t>(wtype) << 15) | (uint64_t(NormType) << 20) |
(uint64_t(NormStage)) << 22 | (uint64_t(NormBackend) << 24) |
(uint64_t(zero_centered_gamma) << 26) | (uint64_t(mode) << 27) |
(uint64_t(training) << 37) | (uint64_t(gamma_in_weight_dtype) << 38);
return std::make_tuple(general_key, batch_size, hidden_size, is_tuned);
}
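
The widened key packs eleven fields into one uint64_t, and the static_asserts pin the two field widths that could plausibly overflow. A spot-check of the layout, confirming the new gamma_in_weight_dtype bit clears the 10-bit mode field (field values illustrative, read off the shifts above):

#include <cassert>
#include <cstdint>

// field        bits   shift        field        bits   shift
// itype         5      0           NormBackend   2     24
// otype         5      5           zero_gamma    1     26
// ctype         5     10           mode         10     27
// wtype         5     15           training      1     37
// NormType      2     20           gamma_wtype   1     38
// NormStage     2     22
uint64_t pack(uint64_t itype, uint64_t otype, uint64_t ctype, uint64_t wtype,
              uint64_t norm_type, uint64_t stage, uint64_t backend,
              uint64_t zero_gamma, uint64_t mode, uint64_t training,
              uint64_t gamma_wtype) {
  return itype | (otype << 5) | (ctype << 10) | (wtype << 15) |
         (norm_type << 20) | (stage << 22) | (backend << 24) |
         (zero_gamma << 26) | (mode << 27) | (training << 37) |
         (gamma_wtype << 38);
}

int main() {
  // Two keys differing only in the new flag must produce distinct cache entries.
  uint64_t a = pack(1, 2, 3, 4, 1, 1, 1, 0, 0, 1, 0);
  uint64_t b = pack(1, 2, 3, 4, 1, 1, 1, 0, 0, 1, 1);
  assert(a != b);
  assert((b >> 38) & 1);  // flag lands in bit 38, clear of mode at bits 27-36
  return 0;
}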
......@@ -216,8 +218,11 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor
}
const auto gamma_dtype = use_zero_centered_gamma_in_weight_dtype() ? wtype : ctype;
NVTE_CHECK(gamma_dtype == DType::kFloat32 || gamma_dtype == DType::kFloat16 ||
gamma_dtype == DType::kBFloat16,
"Gamma of type FP4 is not supported");
_scalar_dptr = std::make_unique<char[]>(typeToSize(gamma_dtype));
_scalar_dptr = std::make_unique<char[]>(typeToNumBits(gamma_dtype) / 8);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
gamma_dtype, cpp_dtype,
*(reinterpret_cast<cpp_dtype*>(_scalar_dptr.get())) = (cpp_dtype)1.0f;);
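
The _scalar_dptr buffer above stores a single 1 in whatever dtype gamma resolves to, with the TYPE_SWITCH macro dispatching over DType at runtime. A standalone sketch of the same pattern, using a plain template in place of the macro:

#include <cstring>
#include <memory>

template <typename T>
std::unique_ptr<char[]> make_unit_scalar() {
  auto buf = std::make_unique<char[]>(sizeof(T));
  const T one = static_cast<T>(1.0f);
  std::memcpy(buf.get(), &one, sizeof(T));  // typed store into the raw byte buffer
  return buf;
}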
......@@ -490,11 +495,12 @@ NormalizationPlanBase* NormalizationPlanRegistry::getNormalizationPlan(
NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype,
DType itype, DType otype, const size_t batch_size, const size_t hidden_size,
const size_t sm_count, const bool zero_centered_gamma, const bool is_aligned,
const NVTEScalingMode mode, const bool training) {
const NVTEScalingMode mode, const bool training, const bool gamma_in_weight_dtype) {
const DType ctype = DType::kFloat32;
bool is_tuned = is_aligned && (batch_size % 4 == 0);
auto key = get_key(NormBackend, NormType, NormStage, wtype, itype, otype, ctype, batch_size,
hidden_size, zero_centered_gamma, is_tuned, mode, training);
auto key =
get_key(NormBackend, NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size,
zero_centered_gamma, is_tuned, mode, training, gamma_in_weight_dtype);
auto it = normalizationPlanMap.find(key);
if (it != normalizationPlanMap.end()) {
......@@ -577,6 +583,7 @@ void nvte_enable_cudnn_norm_bwd(bool enable) {
#endif
}
// Only for testing, not thread-safe
void nvte_enable_zero_centered_gamma_in_weight_dtype(bool enable) {
NVTE_API_CALL(nvte_enable_zero_centered_gamma_in_weight_dtype);
#ifdef USE_ROCM
......
......@@ -163,7 +163,7 @@ TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType,
NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype,
uint64_t batch_size, uint64_t hidden_size, bool zero_centered_gamma,
bool is_tuned, NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING,
bool training = true);
bool training = true, bool gamma_in_weight_dtype = false);
template <typename KernelParamsType>
class TeNormalizationRegistry {
......@@ -313,7 +313,8 @@ class NormalizationPlanRegistry {
NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage,
DType wtype, DType itype, DType otype, const size_t batch_size, const size_t hidden_size,
const size_t sm_count, const bool zero_centered_gamma, const bool is_aligned,
const NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING, const bool training = true);
const NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING, const bool training = true,
const bool gamma_in_weight_dtype = false);
private:
NormalizationPlanRegistry() {}
......@@ -392,6 +393,8 @@ bool is_ptr_aligned(const Args*... ptrs) {
bool use_cudnn_norm_fwd();
bool use_cudnn_norm_bwd();
bool& use_zero_centered_gamma_in_weight_dtype();
} // namespace normalization
} // namespace transformer_engine
......
......@@ -15,6 +15,7 @@
#include "../../common.h"
#include "../common.h"
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine {
......@@ -71,9 +72,11 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
#endif
bool gamma_in_weight_dtype = false;
if (cudnn_backend) {
// TODO: add check for GPU ARCH
norm_backend = NVTE_Norm_Backend::Cudnn;
gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype();
} else {
norm_backend = NVTE_Norm_Backend::Te;
is_aligned = is_ptr_aligned(z->data.dptr, x.data.dptr, gamma.data.dptr, beta.data.dptr,
......@@ -90,7 +93,8 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
z->data.dtype, // otype
x.data.shape[0], // batch_size
x.data.shape[1], // hidden_size
multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training);
multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training,
gamma_in_weight_dtype);
if (workspace->data.shape.empty()) {
workspace->data.shape = plan->getWorkspaceShape();
......@@ -108,11 +112,11 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
// Compute FP8 transpose if required
if (z->has_columnwise_data() && is_tensor_scaling(z->scaling_mode)) {
Tensor transpose_data;
transpose_data.data = z->columnwise_data;
transpose_data.scaling_mode = z->scaling_mode;
nvte_transpose(reinterpret_cast<NVTETensor>(z), reinterpret_cast<NVTETensor>(&transpose_data),
stream);
NVTETensor transpose_data = nvte_create_tensor(z->scaling_mode);
Tensor& t = *convertNVTETensor(transpose_data);
t.data = z->columnwise_data;
nvte_transpose(static_cast<NVTETensor>(*z), transpose_data, stream);
nvte_destroy_tensor(transpose_data);
}
return;
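
Note the transpose scratch tensor can no longer live on the stack: with allocator-backed handles, only nvte_create_tensor yields a valid NVTETensor, so the code now pairs create with destroy. A hypothetical RAII helper (not part of the TE API) would make that pairing exception-safe:

#include <transformer_engine/transformer_engine.h>

class ScopedNVTETensor {
 public:
  explicit ScopedNVTETensor(NVTEScalingMode mode) : t_(nvte_create_tensor(mode)) {}
  ~ScopedNVTETensor() { nvte_destroy_tensor(t_); }
  ScopedNVTETensor(const ScopedNVTETensor&) = delete;
  ScopedNVTETensor& operator=(const ScopedNVTETensor&) = delete;
  NVTETensor get() const { return t_; }  // borrow the handle for API calls

 private:
  NVTETensor t_;
};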
......@@ -157,9 +161,11 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te
NVTE_Norm_Backend norm_backend;
bool is_aligned = true;
bool gamma_in_weight_dtype = false;
if (use_cudnn_norm_bwd()) {
// TODO: add check for GPU ARCH
norm_backend = NVTE_Norm_Backend::Cudnn;
gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype();
} else {
norm_backend = NVTE_Norm_Backend::Te;
is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, mu.data.dptr, rsigma.data.dptr,
......@@ -172,7 +178,8 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te
gamma.data.dtype, // otype
x.data.shape[0], // batch_size
x.data.shape[1], // hidden_size
multiprocessorCount, zero_centered_gamma, is_aligned);
multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true,
gamma_in_weight_dtype);
if (workspace->data.shape.empty()) {
workspace->data.shape = plan->getWorkspaceShape();
......@@ -195,11 +202,10 @@ void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size
const bool zero_centered_gamma, cudaStream_t stream) {
NVTE_API_CALL(nvte_layernorm_fwd);
using namespace transformer_engine;
layernorm_fwd(*reinterpret_cast<const Tensor*>(x), *reinterpret_cast<const Tensor*>(gamma),
*reinterpret_cast<const Tensor*>(beta), epsilon, reinterpret_cast<Tensor*>(z),
reinterpret_cast<Tensor*>(mu), reinterpret_cast<Tensor*>(rsigma),
reinterpret_cast<Tensor*>(workspace), multiprocessorCount, zero_centered_gamma,
stream);
layernorm_fwd(*convertNVTETensorCheck(x), *convertNVTETensorCheck(gamma),
*convertNVTETensorCheck(beta), epsilon, convertNVTETensor(z), convertNVTETensor(mu),
convertNVTETensor(rsigma), convertNVTETensor(workspace), multiprocessorCount,
zero_centered_gamma, stream);
}
void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size
......@@ -212,10 +218,9 @@ void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size
cudaStream_t stream) {
NVTE_API_CALL(nvte_layernorm_bwd);
using namespace transformer_engine;
layernorm_bwd(*reinterpret_cast<const Tensor*>(dz), *reinterpret_cast<const Tensor*>(x),
*reinterpret_cast<const Tensor*>(mu), *reinterpret_cast<const Tensor*>(rsigma),
*reinterpret_cast<const Tensor*>(gamma), reinterpret_cast<Tensor*>(dx),
reinterpret_cast<Tensor*>(dgamma), reinterpret_cast<Tensor*>(dbeta),
reinterpret_cast<Tensor*>(workspace), multiprocessorCount, zero_centered_gamma,
stream);
layernorm_bwd(*convertNVTETensorCheck(dz), *convertNVTETensorCheck(x),
*convertNVTETensorCheck(mu), *convertNVTETensorCheck(rsigma),
*convertNVTETensorCheck(gamma), convertNVTETensor(dx), convertNVTETensor(dgamma),
convertNVTETensor(dbeta), convertNVTETensor(workspace), multiprocessorCount,
zero_centered_gamma, stream);
}
......@@ -13,6 +13,7 @@
#include "../../common.h"
#include "../common.h"
#include "transformer_engine/normalization.h"
#include "transformer_engine/transformer_engine.h"
#include "transformer_engine/transpose.h"
namespace transformer_engine {
......@@ -60,9 +61,11 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
bool training =
is_delayed_tensor_scaling(z->scaling_mode) || (z->columnwise_data).dptr != nullptr;
bool gamma_in_weight_dtype = false;
if (cudnn_backend) {
// TODO: add check for GPU ARCH
norm_backend = NVTE_Norm_Backend::Cudnn;
gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype();
} else {
norm_backend = NVTE_Norm_Backend::Te;
is_aligned = is_ptr_aligned(z->data.dptr, x.data.dptr, gamma.data.dptr, rsigma->data.dptr);
......@@ -75,7 +78,8 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
z->data.dtype, // otype
x.data.shape[0], // batch_size
x.data.shape[1], // hidden_size
multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training);
multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training,
gamma_in_weight_dtype);
if (workspace->data.shape.empty()) {
workspace->data.shape = plan->getWorkspaceShape();
......@@ -93,11 +97,12 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
// Compute FP8 transpose if required
if (z->has_columnwise_data() && is_tensor_scaling(z->scaling_mode)) {
Tensor transpose_data;
transpose_data.data = z->columnwise_data;
transpose_data.scaling_mode = z->scaling_mode;
nvte_transpose(reinterpret_cast<NVTETensor>(z), reinterpret_cast<NVTETensor>(&transpose_data),
stream);
NVTETensor transpose_data = nvte_create_tensor(z->scaling_mode);
auto *t = convertNVTETensor(transpose_data);
t->data = z->columnwise_data;
nvte_transpose(static_cast<NVTETensor>(*z), transpose_data, stream);
nvte_destroy_tensor(transpose_data);
}
return;
......@@ -133,9 +138,11 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
NVTE_Norm_Backend norm_backend;
bool is_aligned = true;
bool gamma_in_weight_dtype = false;
if (use_cudnn_norm_bwd()) {
// TODO: add check for GPU ARCH
norm_backend = NVTE_Norm_Backend::Cudnn;
gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype();
} else {
norm_backend = NVTE_Norm_Backend::Te;
is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, rsigma.data.dptr, dx->data.dptr,
......@@ -148,7 +155,8 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
gamma.data.dtype, // otype
x.data.shape[0], // batch_size
x.data.shape[1], // hidden_size
multiprocessorCount, zero_centered_gamma, is_aligned);
multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true,
gamma_in_weight_dtype);
if (workspace->data.shape.empty()) {
workspace->data.shape = plan->getWorkspaceShape();
......@@ -171,10 +179,9 @@ void nvte_rmsnorm_fwd(const NVTETensor x, // Nxhidden_size
cudaStream_t stream) {
NVTE_API_CALL(nvte_rmsnorm_fwd);
using namespace transformer_engine;
rmsnorm_fwd(*reinterpret_cast<const Tensor *>(x), *reinterpret_cast<const Tensor *>(gamma),
epsilon, reinterpret_cast<Tensor *>(z), reinterpret_cast<Tensor *>(rsigma),
reinterpret_cast<Tensor *>(workspace), multiprocessorCount, zero_centered_gamma,
stream);
rmsnorm_fwd(*convertNVTETensorCheck(x), *convertNVTETensorCheck(gamma), epsilon,
convertNVTETensor(z), convertNVTETensor(rsigma), convertNVTETensor(workspace),
multiprocessorCount, zero_centered_gamma, stream);
}
void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size
......@@ -186,9 +193,8 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size
cudaStream_t stream) {
NVTE_API_CALL(nvte_rmsnorm_bwd);
using namespace transformer_engine;
rmsnorm_bwd(*reinterpret_cast<const Tensor *>(dz), *reinterpret_cast<const Tensor *>(x),
*reinterpret_cast<const Tensor *>(rsigma), *reinterpret_cast<const Tensor *>(gamma),
reinterpret_cast<Tensor *>(dx), reinterpret_cast<Tensor *>(dgamma),
reinterpret_cast<Tensor *>(workspace), multiprocessorCount, zero_centered_gamma,
stream);
rmsnorm_bwd(*convertNVTETensorCheck(dz), *convertNVTETensorCheck(x),
*convertNVTETensorCheck(rsigma), *convertNVTETensorCheck(gamma),
convertNVTETensor(dx), convertNVTETensor(dgamma), convertNVTETensor(workspace),
multiprocessorCount, zero_centered_gamma, stream);
}
......@@ -334,22 +334,16 @@ void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor so
NVTETensor row_id_map, const NVTETensor prob, NVTETensor prob_grad,
const NVTETensor input_fwd, const int num_rows, const int topK,
const int num_cols, const int num_out_tokens, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_API_CALL(nvte_permute);
const transformer_engine::Tensor *input_cu =
reinterpret_cast<const transformer_engine::Tensor *>(input);
const transformer_engine::Tensor *output_cu =
reinterpret_cast<const transformer_engine::Tensor *>(output);
const transformer_engine::Tensor *sorted_row_id_cu =
reinterpret_cast<const transformer_engine::Tensor *>(sorted_row_id);
const transformer_engine::Tensor *row_id_map_cu =
reinterpret_cast<const transformer_engine::Tensor *>(row_id_map);
const transformer_engine::Tensor *prob_cu =
reinterpret_cast<const transformer_engine::Tensor *>(prob);
const transformer_engine::Tensor *prob_grad_cu =
reinterpret_cast<const transformer_engine::Tensor *>(prob_grad);
const transformer_engine::Tensor *input_fwd_cu =
reinterpret_cast<const transformer_engine::Tensor *>(input_fwd);
const Tensor *input_cu = convertNVTETensorCheck(input);
const Tensor *output_cu = convertNVTETensorCheck(output);
const Tensor *sorted_row_id_cu = convertNVTETensorCheck(sorted_row_id);
const Tensor *row_id_map_cu = convertNVTETensorCheck(row_id_map);
const Tensor *prob_cu = convertNVTETensorCheck(prob);
const Tensor *prob_grad_cu = convertNVTETensorCheck(prob_grad);
const Tensor *input_fwd_cu = convertNVTETensorCheck(input_fwd);
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
input_cu->data.dtype, T,
......@@ -366,16 +360,13 @@ void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor so
void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id_map,
const NVTETensor prob, const int num_rows, const int topK, const int num_cols,
cudaStream_t stream) {
using namespace transformer_engine;
NVTE_API_CALL(nvte_unpermute);
const transformer_engine::Tensor *input_cu =
reinterpret_cast<const transformer_engine::Tensor *>(input);
const transformer_engine::Tensor *output_cu =
reinterpret_cast<const transformer_engine::Tensor *>(output);
const transformer_engine::Tensor *row_id_map_cu =
reinterpret_cast<const transformer_engine::Tensor *>(row_id_map);
const transformer_engine::Tensor *prob_cu =
reinterpret_cast<const transformer_engine::Tensor *>(prob);
const Tensor *input_cu = convertNVTETensorCheck(input);
const Tensor *output_cu = convertNVTETensorCheck(output);
const Tensor *row_id_map_cu = convertNVTETensorCheck(row_id_map);
const Tensor *prob_cu = convertNVTETensorCheck(prob);
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
input_cu->data.dtype, T,
......
......@@ -180,6 +180,7 @@ class DelayedScaling(Recipe):
def __repr__(self) -> str:
return (
f"recipe_type={self.__class__.__name__}, "
f"margin={self.margin}, "
f"format={str(self.fp8_format).split('.')[1]}, "
f"amax_history_len={self.amax_history_len}, "
......@@ -192,42 +193,12 @@ class DelayedScaling(Recipe):
class Float8CurrentScaling(Recipe):
"""
Use the per-tensor current scaling factor strategy.
Parameters
----------
fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.HYBRID
Controls the FP8 data format used during forward and backward
pass.
fp8_quant_fwd_inp: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0}
used for quantization of input tensor x
fp8_quant_fwd_weight: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0}
used for quantization of weight tensor w
fp8_quant_bwd_grad: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0}
used for quantization of gradient tensor dY
fp8_gemm_fprop: MMParams, default MMParams.use_split_accumulator=False
used for calculating output y in forward pass
fp8_gemm_dgrad: MMParams, default MMParams.use_split_accumulator=True
used for calculating dgrad in backward pass
fp8_gemm_wgrad: MMParams, default MMParams.use_split_accumulator=True
used for calculating wgrad in backward pass
fp8_dpa: bool, default = `False`
Whether to enable FP8 dot product attention (DPA). When the model is placed in an
`fp8_autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the
inputs from higher precision to FP8, performs attention in FP8, and casts tensors
back to higher precision as outputs. FP8 DPA currently is only supported in the
`FusedAttention` backend.
fp8_mha: bool, default = `False`
Whether to enable FP8 multi-head attention (MHA). When `True`, it removes the casting
operations mentioned above at the DPA boundaries. Currently only standard MHA modules
i.e. `LayerNormLinear/Linear + DPA + Linear`, are supported for this feature. When
`fp8_mha = False, fp8_dpa = True`, a typical MHA module works as
`LayerNormLinear (BF16 output) -> (cast to FP8 ) FP8 DPA (cast to BF16) -> Linear`.
When `fp8_mha = True, fp8_dpa = True`, it becomes
`LayerNormLinear (FP8 output) -> FP8 DPA -> Linear`.
Notes
-----
* `fp8_dpa` and `fp8_mha` are Beta features, and their API and functionality are
subject to change in future Transformer Engine releases.
"""
fp8_format: Format = Format.HYBRID
......@@ -242,9 +213,13 @@ class Float8CurrentScaling(Recipe):
def __post_init__(self) -> None:
assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
assert (
not self.fp8_dpa and not self.fp8_mha
), "FP8 attention is not supported for Float8CurrentScaling."
def __repr__(self) -> str:
return (
f"recipe_type={self.__class__.__name__}, "
f"format={str(self.fp8_format).split('.')[1]}, "
f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, "
f"fp8_quant_fwd_weight={self.fp8_quant_fwd_weight}, "
......@@ -291,7 +266,11 @@ class MXFP8BlockScaling(Recipe):
assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
def __repr__(self) -> str:
return f"margin={self.margin}, format={str(self.fp8_format).split('.')[1]},"
return (
f"recipe_type={self.__class__.__name__}, "
f"margin={self.margin}, "
f"format={str(self.fp8_format).split('.')[1]}"
)
@dataclass()
......@@ -313,32 +292,12 @@ class Float8BlockScaling(Recipe):
NOTE: To relax the default constraint that scales be powers of 2, set env variable
NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 to override it for the recipe defaults.
export NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1
Or initialize the Recipe with non-default QParams in code for increased control.
Parameters
----------
fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3
Controls the FP8 data format used during forward and backward
pass.
fp8_quant_fwd_inp: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0}
used for quantization of input tensor x
fp8_quant_fwd_weight: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0}
used for quantization of weight tensor w
fp8_quant_bwd_grad: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0}
used for quantization of gradient tensor dY
x_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional)
qblock scaling for x.
w_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional)
qblock scaling for w.
grad_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional)
qblock scaling for grad.
fp8_gemm_fprop: MMParams, default MMParams.use_split_accumulator=False
used for calculating output y in forward pass
fp8_gemm_dgrad: MMParams, default MMParams.use_split_accumulator=True
used for calculating dgrad in backward pass
fp8_gemm_wgrad: MMParams, default MMParams.use_split_accumulator=True
used for calculating wgrad in backward pass
"""
use_f32_scales: bool = os.getenv("NVTE_FP8_BLOCK_SCALING_FP32_SCALES", "0") == "1"
......@@ -372,9 +331,13 @@ class Float8BlockScaling(Recipe):
assert self.fp8_gemm_fprop.use_split_accumulator, "Split accumulator required for fprop."
assert self.fp8_gemm_dgrad.use_split_accumulator, "Split accumulator required for dgrad."
assert self.fp8_gemm_wgrad.use_split_accumulator, "Split accumulator required for wgrad."
assert (
not self.fp8_dpa and not self.fp8_mha
), "FP8 attention is not supported for Float8BlockScaling."
def __repr__(self) -> str:
return (
f"recipe_type={self.__class__.__name__}, "
f"format={str(self.fp8_format).split('.')[1]}, "
f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, "
f"fp8_quant_fwd_weight={self.fp8_quant_fwd_weight}, "
......
......@@ -112,7 +112,7 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt
// Check input tensor
NVTE_CHECK(input_ != nullptr, "Invalid input tensor (got NULL)");
const auto &input = *reinterpret_cast<const Tensor *>(input_);
const auto &input = *convertNVTETensorCheck(input_);
NVTE_CHECK(input.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Input tensor for amax computation must unquantized, "
"but got scaling_mode=",
......@@ -125,7 +125,7 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt
// Check output tensor
NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)");
auto &output = *reinterpret_cast<Tensor *>(output_);
auto &output = *convertNVTETensorCheck(output_);
NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Output tensor for amax computation must be FP8 tensor with per-tensor scaling, "
"but got scaling_mode=",
......@@ -170,7 +170,7 @@ void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConf
// Check output tensor
NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)");
auto &output = *reinterpret_cast<Tensor *>(output_);
auto &output = *convertNVTETensorCheck(output_);
NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Tensor must be FP8 tensor with per-tensor scaling, "
"but got scaling_mode=",
......
......@@ -397,9 +397,9 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(
NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update);
using namespace transformer_engine;
delayed_scaling_recipe::amax_and_scale_update(
*reinterpret_cast<const Tensor*>(amax_history), *reinterpret_cast<const Tensor*>(scale),
reinterpret_cast<Tensor*>(updated_amax_history), reinterpret_cast<Tensor*>(updated_scale),
amax_compute_algo, static_cast<DType>(fp8_dtype), margin, stream);
*convertNVTETensorCheck(amax_history), *convertNVTETensorCheck(scale),
convertNVTETensor(updated_amax_history), convertNVTETensor(updated_scale), amax_compute_algo,
static_cast<DType>(fp8_dtype), margin, stream);
}
void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
......@@ -411,10 +411,10 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
size_t num_tensors = amax_histories.size();
std::vector<Tensor*> t_amax_histories, t_scales;
for (size_t i = 0; i < num_tensors; i++) {
t_amax_histories.push_back(reinterpret_cast<Tensor*>(amax_histories[i]));
t_scales.push_back(reinterpret_cast<Tensor*>(scales[i]));
t_amax_histories.push_back(convertNVTETensor(amax_histories[i]));
t_scales.push_back(convertNVTETensor(scales[i]));
}
delayed_scaling_recipe::amax_and_scale_update_after_reduction(
*reinterpret_cast<const Tensor*>(amax_reduction_buffer), t_amax_histories, t_scales,
amax_compute_algo, static_cast<DType>(fp8_dtype), margin, stream);
*convertNVTETensorCheck(amax_reduction_buffer), t_amax_histories, t_scales, amax_compute_algo,
static_cast<DType>(fp8_dtype), margin, stream);
}
......@@ -244,8 +244,8 @@ void nvte_fp8_block_scaling_compute_partial_amax(const NVTETensor inp, NVTETenso
NVTE_API_CALL(nvte_fp8_block_scaling_compute_partial_amax);
using namespace transformer_engine;
fp8_block_scaling_recipe::fp8_block_scaling_compute_partial_amax(
*reinterpret_cast<const Tensor *>(inp), *reinterpret_cast<Tensor *>(amax), h, w,
amax_stride_h, amax_stride_w, start_offset, block_len, stream);
*convertNVTETensorCheck(inp), *convertNVTETensorCheck(amax), h, w, amax_stride_h,
amax_stride_w, start_offset, block_len, stream);
}
void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out,
......@@ -256,7 +256,7 @@ void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out,
NVTE_API_CALL(nvte_fp8_block_scaling_partial_cast);
using namespace transformer_engine;
fp8_block_scaling_recipe::fp8_block_scaling_partial_cast(
*reinterpret_cast<const Tensor *>(inp), *reinterpret_cast<Tensor *>(out),
*reinterpret_cast<const Tensor *>(scale), h, w, scale_stride_h, scale_stride_w, start_offset,
block_len, static_cast<DType>(out_dtype), stream);
*convertNVTETensorCheck(inp), *convertNVTETensorCheck(out), *convertNVTETensorCheck(scale), h,
w, scale_stride_h, scale_stride_w, start_offset, block_len, static_cast<DType>(out_dtype),
stream);
}
......@@ -514,6 +514,5 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_swizzle_scaling_factors);
using namespace transformer_engine;
swizzle_scaling_factors(reinterpret_cast<const Tensor*>(input), reinterpret_cast<Tensor*>(output),
stream);
swizzle_scaling_factors(convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream);
}
......@@ -6,20 +6,27 @@
#include <transformer_engine/transformer_engine.h>
#include <atomic>
#include <climits>
#include <cstring>
#include <iostream>
#include <mutex>
#include "common.h"
#include "common/util/cuda_runtime.h"
#include "common/util/logging.h"
namespace transformer_engine {
size_t typeToSize(const DType type) {
size_t typeToNumBits(const DType type) {
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
return TypeInfo<T>::size;); // NOLINT(*)
}
bool is_fp8_dtype(const DType t) { return t == DType::kFloat8E4M3 || t == DType::kFloat8E5M2; }
size_t typeToSize(const DType type) {
NVTE_CHECK(type != DType::kFloat4E2M1, "typeToSize() does not support the FP4 data type.");
return typeToNumBits(type) / 8;
}
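
typeToSize keeps its byte-granular contract by rejecting FP4, whose 4-bit elements have no whole-byte size; code paths that may see FP4 switch to typeToNumBits and size buffers in bits. Illustrative arithmetic:

#include <cstddef>

// numel elements of width `bits` occupy ceil(numel * bits / 8) bytes.
constexpr size_t buffer_bytes(size_t numel, size_t bits) {
  return (numel * bits + 7) / 8;
}

static_assert(buffer_bytes(1024, 8) == 1024);  // FP8: one byte per element
static_assert(buffer_bytes(1024, 4) == 512);   // FP4: two elements per byte
static_assert(buffer_bytes(7, 4) == 4);        // odd element counts round up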
std::string to_string(const DType type) {
switch (type) {
......@@ -37,6 +44,8 @@ std::string to_string(const DType type) {
return "Float8E5M2";
case DType::kFloat8E8M0:
return "Float8E8M0";
case DType::kFloat4E2M1:
return "Float4E2M1";
case DType::kInt32:
return "Int32";
case DType::kInt64:
......@@ -52,6 +61,8 @@ std::string to_string(const NVTEScalingMode &mode) {
return "NVTE_DELAYED_TENSOR_SCALING";
case NVTE_MXFP8_1D_SCALING:
return "NVTE_MXFP8_1D_SCALING";
case NVTE_FWD_NVFP4_BWD_MXFP8_SCALING:
return "NVTE_FWD_NVFP4_BWD_MXFP8_SCALING";
case NVTE_INVALID_SCALING:
return "NVTE_INVALID_SCALING";
}
......@@ -81,10 +92,13 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
t.columnwise_scale_inv.shape, ")");
}
} else {
if (t.scaling_mode == NVTE_MXFP8_1D_SCALING) {
if (t.scaling_mode == NVTE_MXFP8_1D_SCALING ||
t.scaling_mode == NVTE_FWD_NVFP4_BWD_MXFP8_SCALING) {
// Need (4, 128) alignment even for e8 scaling factor
auto block_alignment = std::vector<size_t>{128ul, 4ul};
size_t expected_x, expected_y, alignment;
const size_t block_size_rowwise = (t.scaling_mode == NVTE_MXFP8_1D_SCALING) ? 32 : 16;
const size_t block_size_colwise = 32;
if (t.has_data()) {
alignment = block_alignment[0];
......@@ -92,7 +106,8 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
DIVUP(DIVUP(t.flat_first_dim(), static_cast<size_t>(1)), alignment) * alignment;
alignment = block_alignment[1];
expected_y =
DIVUP(DIVUP(t.flat_last_dim(), static_cast<size_t>(32)), alignment) * alignment;
DIVUP(DIVUP(t.flat_last_dim(), static_cast<size_t>(block_size_rowwise)), alignment) *
alignment;
const auto &expected = std::vector<size_t>{expected_x, expected_y};
NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name,
"\" has invalid scale_inv shape (expected ", expected, ", got ",
......@@ -101,7 +116,8 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
if (t.has_columnwise_data()) {
alignment = block_alignment[1];
expected_x =
DIVUP(DIVUP(t.flat_first_dim(), static_cast<size_t>(32)), alignment) * alignment;
DIVUP(DIVUP(t.flat_first_dim(), static_cast<size_t>(block_size_colwise)), alignment) *
alignment;
alignment = block_alignment[0];
expected_y = DIVUP(DIVUP(t.flat_last_dim(), static_cast<size_t>(1)), alignment) * alignment;
const auto &expected = std::vector<size_t>{expected_x, expected_y};
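
With NVTE_FWD_NVFP4_BWD_MXFP8_SCALING, the rowwise scaling block shrinks from 32 to 16 elements while the columnwise block stays at 32, so the expected scale_inv shapes now differ per direction. A spot-check of the rowwise arithmetic above (data shape illustrative):

#include <cstddef>

constexpr size_t divup(size_t a, size_t b) { return (a + b - 1) / b; }
constexpr size_t align_up(size_t n, size_t a) { return divup(n, a) * a; }

// data shape 1000 x 1000: scale rows pad to 128, scale columns to 4
static_assert(align_up(divup(1000, 1), 128) == 1024);
static_assert(align_up(divup(1000, 32), 4) == 32);  // MXFP8: 32-element blocks
static_assert(align_up(divup(1000, 16), 4) == 64);  // NVFP4 rowwise: 16-element blocks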
......@@ -192,24 +208,139 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
CheckScaleTensorShape(t, name);
}
class TensorAllocator {
public:
static TensorAllocator &instance() {
static TensorAllocator allocator;
return allocator;
}
~TensorAllocator() {}
NVTETensor Allocate(NVTEScalingMode mode) {
std::lock_guard<std::mutex> lock(mutex);
if (!free_list.empty()) {
uintptr_t index = free_list.back();
NVTETensor ret = reinterpret_cast<NVTETensor>(index);
free_list.pop_back();
if (debug) {
std::cout << "Allocated " << index
<< " from free list. Free list size: " << free_list.size() << " and capacity "
<< free_list.capacity() << std::endl;
}
// 1-based indexing
memory[index - 1].scaling_mode = mode;
return ret;
}
if (memory.size() < memory.capacity()) {
memory.emplace_back();
Tensor &t = memory.back();
size = memory.size();
// 1-based indexing
uintptr_t index = memory.size();
if (debug) {
std::cout << "Allocated " << index << ". Memory size: " << memory.size() << " and capacity "
<< memory.capacity() << std::endl;
}
t.scaling_mode = mode;
t.nvte_tensor = reinterpret_cast<NVTETensor>(index);
return reinterpret_cast<NVTETensor>(index);
}
NVTE_ERROR("Cannot allocate a new NVTETensor. Maximum number of tensors reached: ",
MAX_TENSOR_NUM, ". There is probably a memory leak in your application.");
}
void Free(NVTETensor t) {
std::lock_guard<std::mutex> lock(mutex);
uintptr_t index = reinterpret_cast<uintptr_t>(t);
if (index == 0) return;
NVTE_CHECK(index <= memory.size(), "Invalid tensor.");
free_list.push_back(index);
// Clean up
memory[index - 1].clear();
if (debug) {
std::cout << "Freed " << index << ". Free list size: " << free_list.size() << " and capacity "
<< free_list.capacity() << std::endl;
}
}
void Free(NVTETensor *t, size_t N) {
std::lock_guard<std::mutex> lock(mutex);
for (size_t i = 0; i < N; ++i) {
uintptr_t index = reinterpret_cast<uintptr_t>(t[i]);
if (index == 0) continue;
NVTE_CHECK(index <= memory.size(), "Invalid tensor.");
free_list.push_back(index);
// Clean up
memory[index - 1].clear();
}
if (debug) {
std::cout << "Freed range of" << N << " tensors. Free list size: " << free_list.size()
<< " and capacity " << free_list.capacity() << std::endl;
}
}
Tensor *convertNVTETensor(NVTETensor t) {
uintptr_t index = reinterpret_cast<uintptr_t>(t);
// 1-based indexing to enable 0-initialization of NVTETensor
// to be invalid tensor
static_assert(nullptr == 0);
if (index != 0 && index <= size) {
return &(memory[index - 1]);
}
return nullptr;
}
void setDebug(bool debug) {
std::lock_guard<std::mutex> lock(mutex);
this->debug = debug;
}
private:
TensorAllocator() {
std::lock_guard<std::mutex> lock(mutex);
memory.reserve(MAX_TENSOR_NUM);
}
std::mutex mutex;
std::atomic<size_t> size;
// Allocate at most 20 MB for tensors
// Should be replaced by virtual memory allocation
const size_t MAX_TENSOR_NUM = 20 * 1024 * 1024 / sizeof(Tensor);
std::vector<uintptr_t> free_list;
std::vector<Tensor> memory;
bool debug = false;
};
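
The allocator hands out 1-based indices disguised as pointers, recycles freed slots through a free list, and caps the pool at MAX_TENSOR_NUM. From the caller's side this is invisible; the public API keeps its shape. A minimal lifecycle sketch:

#include <transformer_engine/transformer_engine.h>

void lifecycle_example() {
  NVTETensor t = nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING);
  // ... attach buffers via nvte_set_tensor_param and launch kernels ...
  nvte_destroy_tensor(t);     // slot returns to the free list for reuse
  NVTETensor none = nullptr;  // zero-initialized handles mean "no tensor"
  nvte_destroy_tensor(none);  // and freeing them is a no-op
}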
Tensor *convertNVTETensor(const NVTETensor t) {
return TensorAllocator::instance().convertNVTETensor(t);
}
Tensor *convertNVTETensorCheck(const NVTETensor t) {
Tensor *ptr = TensorAllocator::instance().convertNVTETensor(t);
NVTE_CHECK(ptr != nullptr, "Invalid tensor.");
return ptr;
}
} // namespace transformer_engine
NVTETensor nvte_create_tensor(NVTEScalingMode scaling_mode) {
transformer_engine::Tensor *ret = new transformer_engine::Tensor;
ret->scaling_mode = scaling_mode;
NVTETensor ret = transformer_engine::TensorAllocator::instance().Allocate(scaling_mode);
return ret;
}
void nvte_destroy_tensor(NVTETensor tensor) {
if (tensor == nullptr) return;
auto *t = reinterpret_cast<transformer_engine::Tensor *>(tensor);
delete t;
transformer_engine::TensorAllocator::instance().Free(tensor);
}
void nvte_destroy_tensors(NVTETensor *tensors, size_t N) {
transformer_engine::TensorAllocator::instance().Free(tensors, N);
}
NVTEDType nvte_tensor_type(const NVTETensor tensor) {
if (tensor == nullptr) return kNVTEFloat32;
return static_cast<NVTEDType>(
reinterpret_cast<const transformer_engine::Tensor *>(tensor)->dtype());
auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) return kNVTEFloat32;
return static_cast<NVTEDType>(t->dtype());
}
NVTEShape nvte_make_shape(const size_t *data, size_t ndim) {
......@@ -227,23 +358,24 @@ NVTEShape nvte_make_shape(const size_t *data, size_t ndim) {
}
NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
if (tensor == nullptr) {
auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) {
NVTE_ERROR("Invalid tensor");
}
// Determine tensor shape depending on tensor format
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
std::vector<size_t> shape = t.shape();
const std::vector<size_t> &shape = t->shape();
return nvte_make_shape(shape.data(), shape.size());
}
NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) {
if (tensor == nullptr) {
auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) {
NVTE_ERROR("Invalid tensor");
}
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
return nvte_make_shape(t.columnwise_data.shape.data(), t.columnwise_data.shape.size());
const std::vector<size_t> &shape = t->columnwise_data.shape;
return nvte_make_shape(shape.data(), shape.size());
}
size_t nvte_tensor_ndims(const NVTETensor tensor) { return nvte_tensor_shape(tensor).ndim; }
......@@ -264,83 +396,97 @@ size_t nvte_tensor_numel(const NVTETensor tensor) {
return numel;
}
size_t nvte_tensor_element_size_bits(const NVTETensor tensor) {
auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) return 8 * sizeof(float);
return transformer_engine::typeToNumBits(t->dtype());
}
size_t nvte_tensor_element_size(const NVTETensor tensor) {
if (tensor == nullptr) return sizeof(float);
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
return transformer_engine::typeToSize(t.dtype());
auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) return sizeof(float);
NVTE_CHECK(!is_fp4_dtype(t->dtype()),
"For FP4 type please use the nvte_tensor_element_size_bits.");
return nvte_tensor_element_size_bits(tensor) / 8;
}
size_t nvte_tensor_size_bytes(const NVTETensor tensor) {
auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) return 0;
return (nvte_tensor_numel(tensor) * nvte_tensor_element_size_bits(tensor)) / 8;
}
void *nvte_tensor_data(const NVTETensor tensor) {
if (tensor == nullptr) return nullptr;
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
return t.data.dptr;
auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) return nullptr;
return t->data.dptr;
}
void *nvte_tensor_columnwise_data(const NVTETensor tensor) {
if (tensor == nullptr) return nullptr;
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
return t.columnwise_data.dptr;
auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) return nullptr;
return t->columnwise_data.dptr;
}
float *nvte_tensor_amax(const NVTETensor tensor) {
if (tensor == nullptr) return nullptr;
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
NVTE_CHECK(t.amax.dtype == transformer_engine::DType::kFloat32,
auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) return nullptr;
NVTE_CHECK(t->amax.dtype == transformer_engine::DType::kFloat32,
"Tensor's amax must have Float32 type!");
return reinterpret_cast<float *>(t.amax.dptr);
return reinterpret_cast<float *>(t->amax.dptr);
}
float *nvte_tensor_scale(const NVTETensor tensor) {
if (tensor == nullptr) return nullptr;
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
NVTE_CHECK(t.scale.dtype == transformer_engine::DType::kFloat32,
auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) return nullptr;
NVTE_CHECK(t->scale.dtype == transformer_engine::DType::kFloat32,
"Tensor's scale must have Float32 type!");
return reinterpret_cast<float *>(t.scale.dptr);
return reinterpret_cast<float *>(t->scale.dptr);
}
float *nvte_tensor_scale_inv(const NVTETensor tensor) {
if (tensor == nullptr) return nullptr;
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
return reinterpret_cast<float *>(t.scale_inv.dptr);
auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) return nullptr;
return reinterpret_cast<float *>(t->scale_inv.dptr);
}
void *nvte_tensor_columnwise_scale_inv(const NVTETensor tensor) {
if (tensor == nullptr) return nullptr;
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
return t.columnwise_scale_inv.dptr;
auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) return nullptr;
return t->columnwise_scale_inv.dptr;
}
NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor) {
if (tensor == nullptr) {
auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) {
return nvte_make_shape(nullptr, 0);
}
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
return nvte_make_shape(t.scale_inv.shape.data(), t.scale_inv.shape.size());
return nvte_make_shape(t->scale_inv.shape.data(), t->scale_inv.shape.size());
}
void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name,
const NVTEBasicTensor *param) {
NVTE_CHECK(tensor != nullptr, "Tensor pointer can't be NULL.");
NVTE_CHECK(*tensor != nullptr, "Tensor is not allocated.");
auto &t = *reinterpret_cast<transformer_engine::Tensor *>(*tensor);
auto *t = transformer_engine::convertNVTETensor(*tensor);
NVTE_CHECK(t != nullptr, "Tensor is not allocated.");
switch (param_name) {
case kNVTERowwiseData:
t.data = *param;
t->data = *param;
break;
case kNVTEColumnwiseData:
t.columnwise_data = *param;
t->columnwise_data = *param;
break;
case kNVTEScale:
t.scale = *param;
t->scale = *param;
break;
case kNVTEAmax:
t.amax = *param;
t->amax = *param;
break;
case kNVTERowwiseScaleInv:
t.scale_inv = *param;
t->scale_inv = *param;
break;
case kNVTEColumnwiseScaleInv:
t.columnwise_scale_inv = *param;
t->columnwise_scale_inv = *param;
break;
default:
NVTE_ERROR("Unknown tensor parameter!");
......@@ -351,7 +497,7 @@ NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam p
if (tensor == nullptr) {
return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 0)};
}
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
const auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
switch (param_name) {
case kNVTERowwiseData:
return t.data;
......@@ -371,28 +517,30 @@ NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam p
}
NVTEScalingMode nvte_tensor_scaling_mode(const NVTETensor tensor) {
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
if (tensor == nullptr) {
return NVTE_DELAYED_TENSOR_SCALING;
}
const auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
return t.scaling_mode;
}
void nvte_tensor_pack_create(NVTETensorPack *pack) {
for (int i = 0; i < pack->MAX_SIZE; i++) {
pack->tensors[i] = reinterpret_cast<NVTETensor>(new transformer_engine::Tensor);
pack->tensors[i] =
transformer_engine::TensorAllocator::instance().Allocate(NVTE_DELAYED_TENSOR_SCALING);
}
}
void nvte_tensor_pack_destroy(NVTETensorPack *pack) {
for (int i = 0; i < pack->MAX_SIZE; i++) {
auto *t = reinterpret_cast<transformer_engine::Tensor *>(pack->tensors[i]);
delete t;
}
transformer_engine::TensorAllocator::instance().Free(pack->tensors, pack->MAX_SIZE);
}
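
Pack creation and destruction now route through the allocator as well; destruction uses the batched Free overload, taking the lock once instead of MAX_SIZE times. Typical lifecycle:

#include <transformer_engine/transformer_engine.h>

void pack_example() {
  NVTETensorPack pack;
  nvte_tensor_pack_create(&pack);   // fills all MAX_SIZE slots from the allocator
  // ... kernel populates pack.size and the tensor metadata ...
  nvte_tensor_pack_destroy(&pack);  // one batched Free instead of per-tensor deletes
}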
void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) {
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
if (tensor == nullptr) return;
const auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
// Zero out tensor data if allocated
if (t.data.dptr != nullptr) {
size_t size_in_bytes = nvte_tensor_element_size(tensor) * nvte_tensor_numel(tensor);
const size_t size_in_bytes = nvte_tensor_size_bytes(tensor);
cudaMemsetAsync(t.data.dptr, 0, size_in_bytes, stream);
}
// Set amax to 0 if allocated
......@@ -440,6 +588,9 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
case kNVTEQuantizationConfigNoopTensor:
std::memcpy(buf, &config_.noop_tensor, attr_size);
break;
case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat:
std::memcpy(buf, &config_.float8_block_scale_tensor_format, attr_size);
break;
default:
NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
}
......@@ -472,6 +623,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
case kNVTEQuantizationConfigNoopTensor:
std::memcpy(&config_.noop_tensor, buf, attr_size);
break;
case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat:
std::memcpy(&config_.float8_block_scale_tensor_format, buf, attr_size);
break;
default:
NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
}
......
......@@ -348,15 +348,15 @@ void nvte_cast_transpose(const NVTETensor input, NVTETensor output, cudaStream_t
NVTE_API_CALL(nvte_cast_transpose);
using namespace transformer_engine;
auto noop = Tensor();
transformer_engine::detail::cast_transpose(*reinterpret_cast<const Tensor *>(input), noop,
reinterpret_cast<Tensor *>(output), stream);
transformer_engine::detail::cast_transpose(*convertNVTETensorCheck(input), noop,
convertNVTETensor(output), stream);
}
void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_with_noop);
using namespace transformer_engine;
transformer_engine::detail::cast_transpose(*reinterpret_cast<const Tensor *>(input),
*reinterpret_cast<const Tensor *>(noop),
reinterpret_cast<Tensor *>(output), stream);
transformer_engine::detail::cast_transpose(*convertNVTETensorCheck(input),
*convertNVTETensorCheck(noop),
convertNVTETensor(output), stream);
}
......@@ -31,25 +31,27 @@ void quantize_transpose_square_blockwise(const SimpleTensor &input, SimpleTensor
// enum class for rowwise usage
enum class FP8BlockwiseRowwiseOption {
// No rowwise data
// No rowwise data, skip rowwise quantization
NONE,
// Rowwise data, scales in GEMM format
ROWWISE
// TODO: FP8 all gather requires some changes.
// 1. Compact scales are better for gathering than the GEMM format.
ROWWISE_GEMM_READY,
// Rowwise data, scales in compact format, needs extra processing (padding, transposing) before GEMM
ROWWISE_COMPACT
};
// enum class for columnwise usage
// For Hopper sm90, which has only the TN fp8 gemm, columnwise data needs a transpose when doing 1D block scaling
enum class FP8BlockwiseColumnwiseOption {
// No columnwise data
// No columnwise data, skip columnwise quantization
NONE,
// Columnwise data transposed from original shape.
// Scales in GEMM format corresponding to GEMM ingesting transposed column data.
COLUMNWISE_TRANSPOSE
// TODO: FP8 all gather requires some changes.
// 1. The transpose gets in the way of the all gather.
// 2. Compact scales are better for gathering than the GEMM format.
// On Hopper sm90, GEMM_READY means that columnwise quantization also fuses the transpose op
// On higher sm versions with TN, NT, NN fp8 gemms, GEMM_READY doesn't fuse the transpose
COLUMNWISE_GEMM_READY,
// Columnwise data in original shape
// Scales in compact format, needs extra processing (padding, transposing) before GEMM
COLUMNWISE_COMPACT
};
void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor &scale_inv,
......
......@@ -17,6 +17,7 @@
#include "../util/string.h"
#include "../utils.cuh"
#include "cast_transpose.h"
#include "common/common.h"
namespace transformer_engine {
......@@ -196,17 +197,18 @@ void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output, /
workspace->data.dtype = DType::kFloat32;
} else {
// Check that workspace matches expected size
const size_t workspace_size =
const size_t workspace_size = get_buffer_size_bytes(
std::accumulate(workspace->data.shape.begin(), workspace->data.shape.end(), 1,
std::multiplies<size_t>()) *
typeToSize(workspace->data.dtype);
const size_t required_size = num_rows_partial_dbias * row_length * typeToSize(DType::kFloat32);
std::multiplies<size_t>()),
workspace->data.dtype);
const size_t required_size =
get_buffer_size_bytes(num_rows_partial_dbias, row_length, DType::kFloat32);
NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (",
num_rows_partial_dbias, ",", row_length, "), found ())");
NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(",
num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32),
"; found dims=", workspace->data.shape,
", dtype=", typeToSize(workspace->data.dtype), ")");
", dtype=", typeToNumBits(workspace->data.dtype), " bits)");
}
}
......@@ -1337,9 +1339,8 @@ void nvte_cast_transpose_dbias(const NVTETensor input, NVTETensor output, NVTETe
constexpr const NVTETensor activation_input = nullptr;
cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Empty, nullptr>(
*reinterpret_cast<const Tensor *>(input), reinterpret_cast<const Tensor *>(activation_input),
reinterpret_cast<Tensor *>(output), reinterpret_cast<Tensor *>(dbias),
reinterpret_cast<Tensor *>(workspace), stream);
*convertNVTETensorCheck(input), convertNVTETensor(activation_input),
convertNVTETensor(output), convertNVTETensor(dbias), convertNVTETensor(workspace), stream);
}
void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, const NVTETensor act_input,
......@@ -1354,9 +1355,9 @@ void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, const NVTETensor ac
constexpr bool IS_ACT = false;
cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Empty, dgelu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), reinterpret_cast<const Tensor *>(act_input),
reinterpret_cast<Tensor *>(output), reinterpret_cast<Tensor *>(dbias),
reinterpret_cast<Tensor *>(workspace), stream);
*convertNVTETensorCheck(input), convertNVTETensorCheck(act_input),
convertNVTETensorCheck(output), convertNVTETensorCheck(dbias), convertNVTETensor(workspace),
stream);
}
void nvte_cast_transpose_dbias_dsilu(const NVTETensor input, const NVTETensor silu_input,
......@@ -1371,9 +1372,9 @@ void nvte_cast_transpose_dbias_dsilu(const NVTETensor input, const NVTETensor si
constexpr bool IS_ACT = false;
cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Empty, dsilu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), reinterpret_cast<const Tensor *>(silu_input),
reinterpret_cast<Tensor *>(output), reinterpret_cast<Tensor *>(dbias),
reinterpret_cast<Tensor *>(workspace), stream);
*convertNVTETensorCheck(input), convertNVTETensorCheck(silu_input),
convertNVTETensorCheck(output), convertNVTETensorCheck(dbias), convertNVTETensor(workspace),
stream);
}
void nvte_cast_transpose_dbias_drelu(const NVTETensor input, const NVTETensor relu_input,
......@@ -1388,9 +1389,9 @@ void nvte_cast_transpose_dbias_drelu(const NVTETensor input, const NVTETensor re
constexpr bool IS_ACT = false;
cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Empty, drelu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), reinterpret_cast<const Tensor *>(relu_input),
reinterpret_cast<Tensor *>(output), reinterpret_cast<Tensor *>(dbias),
reinterpret_cast<Tensor *>(workspace), stream);
*convertNVTETensorCheck(input), convertNVTETensorCheck(relu_input),
convertNVTETensorCheck(output), convertNVTETensorCheck(dbias), convertNVTETensor(workspace),
stream);
}
void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, const NVTETensor srelu_input,
......@@ -1405,9 +1406,9 @@ void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, const NVTETensor s
constexpr bool IS_ACT = false;
cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Empty, dsrelu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), reinterpret_cast<const Tensor *>(srelu_input),
reinterpret_cast<Tensor *>(output), reinterpret_cast<Tensor *>(dbias),
reinterpret_cast<Tensor *>(workspace), stream);
*convertNVTETensorCheck(input), convertNVTETensorCheck(srelu_input),
convertNVTETensorCheck(output), convertNVTETensorCheck(dbias), convertNVTETensor(workspace),
stream);
}
void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input, const NVTETensor qgelu_input,
......@@ -1422,9 +1423,9 @@ void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input, const NVTETensor q
constexpr bool IS_ACT = false;
cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Empty, dqgelu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), reinterpret_cast<const Tensor *>(qgelu_input),
reinterpret_cast<Tensor *>(output), reinterpret_cast<Tensor *>(dbias),
reinterpret_cast<Tensor *>(workspace), stream);
*convertNVTETensorCheck(input), convertNVTETensorCheck(qgelu_input),
convertNVTETensorCheck(output), convertNVTETensorCheck(dbias), convertNVTETensor(workspace),
stream);
}
void nvte_dgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input,
......@@ -1434,8 +1435,8 @@ void nvte_dgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_a
using namespace transformer_engine::detail;
dgated_act_cast_transpose<ComputeType, Empty, dgelu<fp32, fp32>, gelu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(gated_act_input),
reinterpret_cast<Tensor *>(output), stream);
*convertNVTETensorCheck(input), *convertNVTETensorCheck(gated_act_input),
convertNVTETensorCheck(output), stream);
}
void nvte_dswiglu_cast_transpose(const NVTETensor input, const NVTETensor swiglu_input,
......@@ -1445,8 +1446,8 @@ void nvte_dswiglu_cast_transpose(const NVTETensor input, const NVTETensor swiglu
using namespace transformer_engine::detail;
dgated_act_cast_transpose<ComputeType, Empty, dsilu<fp32, fp32>, silu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(swiglu_input),
reinterpret_cast<Tensor *>(output), stream);
*convertNVTETensorCheck(input), *convertNVTETensorCheck(swiglu_input),
convertNVTETensorCheck(output), stream);
}
void nvte_dreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input,
......@@ -1456,8 +1457,8 @@ void nvte_dreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_a
using namespace transformer_engine::detail;
dgated_act_cast_transpose<ComputeType, Empty, drelu<fp32, fp32>, relu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(gated_act_input),
reinterpret_cast<Tensor *>(output), stream);
*convertNVTETensorCheck(input), *convertNVTETensorCheck(gated_act_input),
convertNVTETensorCheck(output), stream);
}
void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input,
......@@ -1467,8 +1468,8 @@ void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_
using namespace transformer_engine::detail;
dgated_act_cast_transpose<ComputeType, Empty, dsrelu<fp32, fp32>, srelu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(gated_act_input),
reinterpret_cast<Tensor *>(output), stream);
*convertNVTETensorCheck(input), *convertNVTETensorCheck(gated_act_input),
convertNVTETensorCheck(output), stream);
}
void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input,
......@@ -1478,6 +1479,6 @@ void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_
using namespace transformer_engine::detail;
dgated_act_cast_transpose<ComputeType, Empty, dqgelu<fp32, fp32>, qgelu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(gated_act_input),
reinterpret_cast<Tensor *>(output), stream);
*convertNVTETensorCheck(input), *convertNVTETensorCheck(gated_act_input),
convertNVTETensorCheck(output), stream);
}
......@@ -237,8 +237,8 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list, std::vector<Ten
// Input matrices are divided into tiles
// Note: Each tile is a warp_size x warp_size grid of nvec_out x nvec_in subtiles
const int tile_dim_m = THREADS_PER_WARP * desired_store_size / typeToSize(otype);
const int tile_dim_n = THREADS_PER_WARP * desired_load_size / typeToSize(itype);
const int tile_dim_m = THREADS_PER_WARP * desired_store_size * 8 / typeToNumBits(otype);
const int tile_dim_n = THREADS_PER_WARP * desired_load_size * 8 / typeToNumBits(itype);
// Add tensors to kernel argument struct
MultiCastTransposeArgs kernel_args_aligned, kernel_args_unaligned;
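
Dividing by typeToNumBits instead of typeToSize keeps the tile math exact for sub-byte element types. A worked example with illustrative constants (THREADS_PER_WARP = 32, desired store size 8 bytes):

#include <cstddef>

constexpr size_t tile_dim(size_t warp_size, size_t vec_bytes, size_t elem_bits) {
  return warp_size * vec_bytes * 8 / elem_bits;
}

static_assert(tile_dim(32, 8, 8) == 256);  // FP8: same result as the byte-based formula
static_assert(tile_dim(32, 8, 4) == 512);  // FP4: the byte-based formula would divide by 0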
......@@ -334,8 +334,8 @@ void nvte_multi_cast_transpose(size_t num_tensors, const NVTETensor* input_list,
using namespace transformer_engine;
std::vector<Tensor*> input_list_, output_list_;
for (size_t i = 0; i < num_tensors; ++i) {
input_list_.push_back(reinterpret_cast<Tensor*>(const_cast<NVTETensor&>(input_list[i])));
output_list_.push_back(reinterpret_cast<Tensor*>(output_list[i]));
input_list_.push_back(convertNVTETensorCheck(input_list[i]));
output_list_.push_back(convertNVTETensorCheck(output_list[i]));
}
multi_cast_transpose(input_list_, output_list_, stream);
}
......@@ -483,7 +483,7 @@ CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size
CUtensorMap tensor_map_output_trans{};
create_2D_tensor_map(tensor_map_output_trans, tensor, global_dim_y, global_dim_x,
/*shmemY=*/BLOCK_TILE_DIM, /*shmemX=*/BLOCK_TILE_DIM,
/*stride_elems=*/global_dim_x, /*offset_elems=*/0, sizeof(OutputType));
/*stride_elems=*/global_dim_x, /*offset_elems=*/0, sizeof(OutputType) * 8);
return tensor_map_output_trans;
}
#endif
......