Unverified commit a5ba71f3, authored by Przemyslaw Tredak, committed by GitHub

Move the amax/scale/scale_inv into the TE Tensor struct. (#33)



* Move the amax/scale/scale_inv into the TE Tensor struct.
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Handle multi_cast_transpose
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Changed softmax to new Tensor
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* First pass at the cpp tests
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Round of fixes
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix multi_cast_transpose
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix cast_to_fp8
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemyslaw Tredak <ptrendx@gmail.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 509bf877
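
For orientation before the diff: the FP8 metadata (amax, scale, scale_inv) is now attached to a tensor when it is created, instead of being passed to every kernel as separate NVTETensor arguments. A minimal sketch of a caller after this change (variable names and buffer setup are illustrative, not part of this diff):

    // Sketch only: out_ptr/out_shape/out_dtype and the three fp32 device
    // scalars d_amax/d_scale/d_scale_inv are assumed to be allocated by the
    // caller. The C and T outputs share the same metadata buffers, which the
    // new checks in cast_transpose below require.
    NVTETensor cast_out   = nvte_create_tensor(out_ptr, out_shape, out_dtype,
                                               d_amax, d_scale, d_scale_inv);
    NVTETensor cast_out_t = nvte_create_tensor(out_t_ptr, out_t_shape, out_dtype,
                                               d_amax, d_scale, d_scale_inv);
    nvte_cast_transpose(input, cast_out, cast_out_t, stream);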
@@ -24,19 +24,13 @@ extern "C" {
  * - `transposed_output` is the transposed result of the cast.
  *
  * \param[in]     input             Input tensor of shape [N, H].
- * \param[in]     scale             Scaling factor used for outputs.
- * \param[out]    cast_output       Result of the cast. Shape: [N, H].
- * \param[out]    transposed_output Result of the cast and transpose. Shape: [H, N].
- * \param[in,out] amax              AMAX value of the output tensor.
- * \param[out]    scale_inv         Inverse of the output's scaling factor.
+ * \param[in,out] cast_output       Result of the cast. Shape: [N, H].
+ * \param[in,out] transposed_output Result of the cast and transpose. Shape: [H, N].
  * \param[in]     stream            CUDA stream used for the operation.
  */
 void nvte_cast_transpose(const NVTETensor input,
-                         const NVTETensor scale,
                          NVTETensor cast_output,
                          NVTETensor transposed_output,
-                         NVTETensor amax,
-                         NVTETensor scale_inv,
                          cudaStream_t stream);
 
 /*! \brief Transpose the input.
@@ -60,23 +54,17 @@ void nvte_transpose(const NVTETensor input,
  * but instead set the shape and type of the workspace tensor to the required values.
  *
  * \param[in]     input             Input tensor of shape [N, H].
- * \param[in]     scale             Scaling factor used for outputs.
- * \param[out]    cast_output       Result of the cast. Shape: [N, H].
- * \param[out]    transposed_output Result of the cast and transpose. Shape: [H, N].
- * \param[in,out] amax              AMAX value of the output tensor.
+ * \param[in,out] cast_output       Result of the cast. Shape: [N, H].
+ * \param[in,out] transposed_output Result of the cast and transpose. Shape: [H, N].
  * \param[out]    dbias             Result of the reduction of the input along the
  *                                  first dimension. Shape: [H].
- * \param[out]    scale_inv         Inverse of the output's scaling factor.
  * \param[out]    workspace         Workspace tensor.
  * \param[in]     stream            CUDA stream used for the operation.
  */
 void nvte_cast_transpose_dbias(const NVTETensor input,
-                               const NVTETensor scale,
                                NVTETensor cast_output,
                                NVTETensor transposed_output,
-                               NVTETensor amax,
                                NVTETensor dbias,
-                               NVTETensor scale_inv,
                                NVTETensor workspace,
                                cudaStream_t stream);
@@ -94,24 +82,18 @@ void nvte_cast_transpose_dbias(const NVTETensor input,
  * \param[in]     input             Input tensor of shape [N, H].
  * \param[in]     gelu_input        Tensor used as input to the forward of GELU operation.
  *                                  Shape [N, H].
- * \param[in]     scale             Scaling factor used for outputs.
- * \param[out]    cast_output       Result of the cast. Shape: [N, H].
- * \param[out]    transposed_output Result of the cast and transpose. Shape: [H, N].
- * \param[in,out] amax              AMAX value of the output tensor.
+ * \param[in,out] cast_output       Result of the cast. Shape: [N, H].
+ * \param[in,out] transposed_output Result of the cast and transpose. Shape: [H, N].
  * \param[out]    dbias             Result of the reduction of the dGELU(input) along the
  *                                  first dimension. Shape: [H].
- * \param[out]    scale_inv         Inverse of the output's scaling factor.
  * \param[out]    workspace         Workspace tensor.
  * \param[in]     stream            CUDA stream used for the operation.
  */
 void nvte_cast_transpose_dbias_dgelu(const NVTETensor input,
                                      const NVTETensor gelu_input,
-                                     const NVTETensor scale,
                                      NVTETensor cast_output,
                                      NVTETensor transposed_output,
-                                     NVTETensor amax,
                                      NVTETensor dbias,
-                                     NVTETensor scale_inv,
                                      NVTETensor workspace,
                                      cudaStream_t stream);
@@ -123,23 +105,17 @@ void nvte_cast_transpose_dbias_dgelu(const NVTETensor input,
  *
  * \param[in]     num_tensors             Number of tensors.
  * \param[in]     input_list              List of 2D input tensors.
- * \param[in]     scale_list              Scaling factor to generate outputs.
- * \param[out]    cast_output_list        List of casted tensors. Dimensions
+ * \param[in,out] cast_output_list        List of casted tensors. Dimensions
  *                                        match tensors in input_list.
- * \param[out]    transposed_output_list  List of casted and transposed
+ * \param[in,out] transposed_output_list  List of casted and transposed
  *                                        tensors. Dimensions are transpose
  *                                        of tensors in input_list.
- * \param[in,out] amax_list               AMAX values of the output tensors.
- * \param[out]    scale_inv_list          Inverses of the scaling factors.
  * \param[in]     stream                  CUDA stream used for the operation.
  */
 void nvte_multi_cast_transpose(size_t num_tensors,
                                const NVTETensor* input_list,
-                               const NVTETensor* scale_list,
                                NVTETensor* cast_output_list,
                                NVTETensor* transposed_output_list,
-                               NVTETensor* amax_list,
-                               NVTETensor* scale_inv_list,
                                cudaStream_t stream);
 
 #ifdef __cplusplus
@@ -141,10 +141,64 @@ extern BwdGeneralRegistry BWD_GENERAL_FUNCS;
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+template<typename T>
+struct TypeId{};
+
+template<>
+struct TypeId<fp16>{
+    constexpr static uint32_t Value = 0;
+};
+
+template<>
+struct TypeId<bf16>{
+    constexpr static uint32_t Value = 1;
+};
+
+template<>
+struct TypeId<fp32>{
+    constexpr static uint32_t Value = 2;
+};
+
+template<>
+struct TypeId<fp8e4m3>{
+    constexpr static uint32_t Value = 3;
+};
+
+template<typename T, int S>
+struct Type2Key{
+    constexpr static uint32_t Value = TypeId<T>::Value << S;
+};
+
+template<typename T>
+struct WeightType2Key : public Type2Key<T, 0>{};
+
+template<typename T>
+struct InputType2Key : public Type2Key<T, 2>{};
+
+template<typename T>
+struct OutputType2Key : public Type2Key<T, 4>{};
+
+template<typename T>
+struct ComputeType2Key : public Type2Key<T, 6>{};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename W, typename I, typename O, typename C>
+struct Types2Key{
+    constexpr static uint32_t Value = WeightType2Key<W>::Value | InputType2Key<I>::Value |
+                                      OutputType2Key<O>::Value | ComputeType2Key<C>::Value;
+    constexpr static inline uint64_t get(const uint64_t hidden_size){
+        constexpr uint64_t type_key = Value;
+        return (type_key << 32) | hidden_size;
+    }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 template<typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
 struct FwdTunedRegistrar{
     explicit FwdTunedRegistrar(FwdFunction f){
-        uint64_t key = transformer_engine::Types2Key<W, I, O, C>::get(HIDDEN_SIZE);
+        uint64_t key = Types2Key<W, I, O, C>::get(HIDDEN_SIZE);
         FWD_TUNED_FUNCS.insert({ key, f });
     }
 };
 
@@ -154,7 +208,7 @@ struct FwdTunedRegistrar{
 template<typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
 struct FwdGeneralRegistrar{
     explicit FwdGeneralRegistrar(FwdFunction f){
-        uint64_t key = transformer_engine::Types2Key<W, I, O, C>::get(0);
+        uint64_t key = Types2Key<W, I, O, C>::get(0);
         FWD_GENERAL_FUNCS[key].insert({ HIDDEN_SIZE, f });
     }
 };
 
@@ -164,7 +218,7 @@ struct FwdGeneralRegistrar{
 template<typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
 struct BwdTunedRegistrar{
     explicit BwdTunedRegistrar(BwdFunction f){
-        uint64_t key = transformer_engine::Types2Key<W, I, O, C>::get(HIDDEN_SIZE);
+        uint64_t key = Types2Key<W, I, O, C>::get(HIDDEN_SIZE);
        BWD_TUNED_FUNCS.insert({ key, f });
     }
 };
 
@@ -174,7 +228,7 @@ struct BwdTunedRegistrar{
 template<typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
 struct BwdGeneralRegistrar{
     explicit BwdGeneralRegistrar(BwdFunction f){
-        uint64_t key = transformer_engine::Types2Key<W, I, O, C>::get(0);
+        uint64_t key = Types2Key<W, I, O, C>::get(0);
         BWD_GENERAL_FUNCS[key].insert({ HIDDEN_SIZE, f });
     }
 };
 
@@ -187,6 +241,8 @@ layer_norm::BwdFunction & get_bwd_launcher(DType wtype,
                                            DType ctype,
                                            uint32_t hidden_size);
 
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 }  // namespace layer_norm
 }  // namespace transformer_engine
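
The lookup key packed by Types2Key above is straightforward to verify by hand. A small standalone sketch (not part of the diff) of the same arithmetic:

    #include <cstdint>

    // Each type ID occupies 2 bits: weight<<0, input<<2, output<<4, compute<<6.
    // The 8-bit type key fills the upper 32 bits; hidden_size fills the lower 32.
    // Example: W=fp16 (0), I=bf16 (1), O=fp32 (2), C=fp32 (2), hidden_size=1024.
    constexpr uint32_t type_key = (0u << 0) | (1u << 2) | (2u << 4) | (2u << 6);  // 164
    constexpr uint64_t key = (static_cast<uint64_t>(type_key) << 32) | 1024u;
    static_assert(key == ((164ull << 32) | 1024ull), "matches Types2Key::get(1024)");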
@@ -144,7 +144,6 @@ size_t product(const std::vector<size_t> &shape) {
 void layernorm_fwd(const Tensor& x,      // BxSxhidden_size
                    const Tensor& gamma,  // hidden_size
                    const Tensor& beta,   // hidden_size
-                   const Tensor& scale,
                    const float epsilon,
                    Tensor* z,
                    Tensor* mu,
@@ -152,50 +151,39 @@ void layernorm_fwd(const Tensor& x,      // BxSxhidden_size
                    cudaStream_t stream,
                    const int multiprocessorCount,
                    Tensor* workspace,
-                   Tensor* barrier,
-                   Tensor* amax,
-                   Tensor *scale_inv
-) {
-  auto itype = x.dtype;
-  auto wtype = gamma.dtype;
-  auto otype = z->dtype;
-  bool fp8_out = otype == DType::kFloat8E4M3 ||
-                 otype == DType::kFloat8E5M2;
-  auto ctype = layer_norm::DType::kFloat32;
-
-  NVTE_CHECK(x.shape.size() == 2);
-
-  const size_t rows = x.shape[0];
-  const size_t cols = x.shape[1];
-  auto hidden_size = gamma.shape[0];
-
-  NVTE_CHECK(gamma.shape == beta.shape);
-  NVTE_CHECK(hidden_size == cols);
-
-  NVTE_CHECK(epsilon >= 0.f);
-
-  NVTE_CHECK(z->dptr != nullptr);
-  NVTE_CHECK(z->shape == x.shape);
-
-  NVTE_CHECK(mu->shape == std::vector<size_t>{ rows });
-  NVTE_CHECK(mu->dtype == ctype);
-
-  NVTE_CHECK(rsigma->shape == std::vector<size_t>{ rows });
-  NVTE_CHECK(rsigma->dtype == ctype);
-
-  if (fp8_out) {
-    NVTE_CHECK(scale.shape == std::vector<size_t>{ 1 });
-    NVTE_CHECK(scale.dptr != nullptr);
-    NVTE_CHECK(scale.dtype == ctype);
-    NVTE_CHECK(amax->shape == std::vector<size_t>{ 1 });
-    NVTE_CHECK(amax->dptr != nullptr);
-    NVTE_CHECK(amax->dtype == ctype);
-    NVTE_CHECK(scale_inv->shape == std::vector<size_t>{ 1 });
-    NVTE_CHECK(scale_inv->dptr != nullptr);
-    NVTE_CHECK(scale_inv->dtype == ctype);
-  }
+                   Tensor* barrier) {
+  const auto itype = x.data.dtype;
+  const auto wtype = gamma.data.dtype;
+  const auto otype = z->data.dtype;
+  const bool fp8_out = is_fp8_dtype(otype);
+  const auto ctype = layer_norm::DType::kFloat32;
+
+  CheckInputTensor(x, "x");
+  CheckInputTensor(gamma, "gamma");
+  CheckInputTensor(beta, "beta");
+  CheckOutputTensor(*z, "z");
+  CheckOutputTensor(*mu, "mu");
+  CheckOutputTensor(*rsigma, "rsigma");
+
+  NVTE_CHECK(x.data.shape.size() == 2);
+
+  const size_t rows = x.data.shape[0];
+  const size_t cols = x.data.shape[1];
+  const auto hidden_size = gamma.data.shape[0];
+
+  NVTE_CHECK(gamma.data.shape == beta.data.shape);
+  NVTE_CHECK(hidden_size == cols);
+
+  NVTE_CHECK(epsilon >= 0.f);
+
+  NVTE_CHECK(z->data.shape == x.data.shape);
+
+  NVTE_CHECK(mu->data.shape == std::vector<size_t>{ rows });
+  NVTE_CHECK(mu->data.dtype == ctype);
+
+  NVTE_CHECK(rsigma->data.shape == std::vector<size_t>{ rows });
+  NVTE_CHECK(rsigma->data.dtype == ctype);
 
   layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;
@@ -210,49 +198,49 @@ void layernorm_fwd(const Tensor& x,      // BxSxhidden_size
   layer_norm::FwdParams &params = launch_params.params;
   params.rows = rows;
   params.cols = cols;
-  params.x = x.dptr;
-  params.mu = mu->dptr;
-  params.rs = rsigma->dptr;
-  params.gamma = gamma.dptr;
-  params.beta = beta.dptr;
-  params.z = z->dptr;
+  params.x = x.data.dptr;
+  params.mu = mu->data.dptr;
+  params.rs = rsigma->data.dptr;
+  params.gamma = gamma.data.dptr;
+  params.beta = beta.data.dptr;
+  params.z = z->data.dptr;
   params.epsilon = epsilon;
-  params.amax = amax->dptr;
-  params.scale = scale.dptr;
-  params.scale_inv = scale_inv->dptr;
+  params.amax = z->amax.dptr;
+  params.scale = z->scale.dptr;
+  params.scale_inv = z->scale_inv.dptr;
   params.fp8_out = fp8_out;
 
   // Query the kernel-specific launch parameters.
   launcher(launch_params, true);
 
-  if (workspace->dptr == nullptr) {
-    NVTE_CHECK(barrier->dptr == nullptr);
+  if (workspace->data.dptr == nullptr) {
+    NVTE_CHECK(barrier->data.dptr == nullptr);
 
-    workspace->dtype = layer_norm::DType::kByte;
+    workspace->data.dtype = layer_norm::DType::kByte;
     if (launch_params.workspace_bytes == 0) {
       launch_params.workspace_bytes = 1;
     }
-    workspace->shape = { launch_params.workspace_bytes };
+    workspace->data.shape = { launch_params.workspace_bytes };
 
-    barrier->dtype = layer_norm::DType::kInt32;
-    barrier->shape = { launch_params.barrier_size };
+    barrier->data.dtype = layer_norm::DType::kInt32;
+    barrier->data.shape = { launch_params.barrier_size };
 
     return;
   }
 
   if ( launch_params.barrier_size > 0 ) {
-    params.workspace = workspace->dptr;
-    params.barrier = reinterpret_cast<int*>(barrier->dptr);
+    params.workspace = workspace->data.dptr;
+    params.barrier = reinterpret_cast<int*>(barrier->data.dptr);
   }
 
   // Clear buffers
   if ( params.fp8_out ) {
     cudaMemsetAsync(params.amax, 0,
-                    layer_norm::product(amax->shape) *
-                    typeToSize(amax->dtype), stream);
+                    layer_norm::product(z->amax.shape) *
+                    typeToSize(z->amax.dtype), stream);
   }
   if ( launch_params.barrier_size > 0 ) {
     cudaMemsetAsync(params.barrier, 0,
-                    layer_norm::product(barrier->shape) *
-                    typeToSize(barrier->dtype), stream);
+                    layer_norm::product(barrier->data.shape) *
+                    typeToSize(barrier->data.dtype), stream);
   }
 
   // Launch the kernel.
@@ -278,38 +266,44 @@ void layernorm_bwd(const Tensor& dz,
 ) {
   using namespace transformer_engine;
-  auto itype = x.dtype;
-  auto wtype = gamma.dtype;
+  auto itype = x.data.dtype;
+  auto wtype = gamma.data.dtype;
   auto otype = wtype;
   auto ctype = DType::kFloat32;
 
-  NVTE_CHECK(dz.dtype == otype);
-  NVTE_CHECK(mu.dtype == ctype);
-  NVTE_CHECK(rsigma.dtype == ctype);
+  CheckInputTensor(dz, "dz");
+  CheckInputTensor(x, "x");
+  CheckInputTensor(mu, "mu");
+  CheckInputTensor(rsigma, "rsigma");
+  CheckInputTensor(gamma, "gamma");
+  CheckOutputTensor(*dx, "dx");
+  CheckOutputTensor(*dgamma, "dgamma");
+  CheckOutputTensor(*dbeta, "dbeta");
+
+  NVTE_CHECK(dz.data.dtype == otype);
+  NVTE_CHECK(mu.data.dtype == ctype);
+  NVTE_CHECK(rsigma.data.dtype == ctype);
 
-  NVTE_CHECK(x.shape.size() == 2);
-  NVTE_CHECK(dz.shape == x.shape);
+  NVTE_CHECK(x.data.shape.size() == 2);
+  NVTE_CHECK(dz.data.shape == x.data.shape);
 
-  auto rows = x.shape[0];
-  auto cols = x.shape[1];
-  auto hidden_size = gamma.shape[0];
+  auto rows = x.data.shape[0];
+  auto cols = x.data.shape[1];
+  auto hidden_size = gamma.data.shape[0];
 
-  NVTE_CHECK(mu.shape[0] == rows);
-  NVTE_CHECK(mu.shape == rsigma.shape);
+  NVTE_CHECK(mu.data.shape[0] == rows);
+  NVTE_CHECK(mu.data.shape == rsigma.data.shape);
 
-  NVTE_CHECK(gamma.shape[0] == cols);
+  NVTE_CHECK(gamma.data.shape[0] == cols);
 
-  NVTE_CHECK(dx->shape == x.shape);
-  NVTE_CHECK(dx->dtype == x.dtype);
-  NVTE_CHECK(dx->dptr != nullptr);
+  NVTE_CHECK(dx->data.shape == x.data.shape);
+  NVTE_CHECK(dx->data.dtype == x.data.dtype);
 
-  NVTE_CHECK(dgamma->shape == gamma.shape);
-  NVTE_CHECK(dgamma->dtype == gamma.dtype);
-  NVTE_CHECK(dgamma->dptr != nullptr);
+  NVTE_CHECK(dgamma->data.shape == gamma.data.shape);
+  NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype);
 
-  NVTE_CHECK(dbeta->shape == gamma.shape);
-  NVTE_CHECK(dbeta->dtype == gamma.dtype);
-  NVTE_CHECK(dbeta->dptr != nullptr);
+  NVTE_CHECK(dbeta->data.shape == gamma.data.shape);
+  NVTE_CHECK(dbeta->data.dtype == gamma.data.dtype);
 
   layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
   launch_params.stream = stream;
@@ -322,47 +316,47 @@ void layernorm_bwd(const Tensor& dz,
   layer_norm::BwdParams &params = launch_params.params;
   params.rows = rows;
   params.cols = cols;
-  params.x = x.dptr;
-  params.mu = mu.dptr;
-  params.rs = rsigma.dptr;
-  params.gamma = gamma.dptr;
-  params.dz = dz.dptr;
-  params.dx = dx->dptr;
-  params.dbeta = dbeta->dptr;
-  params.dgamma = dgamma->dptr;
-  params.dbeta_part = dbeta_part->dptr;
-  params.dgamma_part = dgamma_part->dptr;
+  params.x = x.data.dptr;
+  params.mu = mu.data.dptr;
+  params.rs = rsigma.data.dptr;
+  params.gamma = gamma.data.dptr;
+  params.dz = dz.data.dptr;
+  params.dx = dx->data.dptr;
+  params.dbeta = dbeta->data.dptr;
+  params.dgamma = dgamma->data.dptr;
+  params.dbeta_part = dbeta_part->data.dptr;
+  params.dgamma_part = dgamma_part->data.dptr;
 
   // Query the kernel-specific launch parameters.
   launcher(launch_params, true);
 
   // Populate shape and dtypes for FW to allocate memory
-  if (dgamma_part->dptr == nullptr) {
-    NVTE_CHECK(dbeta_part->dptr == nullptr);
+  if (dgamma_part->data.dptr == nullptr) {
+    NVTE_CHECK(dbeta_part->data.dptr == nullptr);
 
-    dgamma_part->dtype = ctype;
-    dgamma_part->shape = { static_cast<uint64_t> (launch_params.params.ctas_per_col),
-                           hidden_size };
+    dgamma_part->data.dtype = ctype;
+    dgamma_part->data.shape = { static_cast<uint64_t> (launch_params.params.ctas_per_col),
+                                hidden_size };
 
-    dbeta_part->dtype = ctype;
-    dbeta_part->shape = { static_cast<uint64_t> (launch_params.params.ctas_per_col),
-                          hidden_size };
+    dbeta_part->data.dtype = ctype;
+    dbeta_part->data.shape = { static_cast<uint64_t> (launch_params.params.ctas_per_col),
+                               hidden_size };
 
-    workspace->dtype = layer_norm::DType::kByte;
-    workspace->shape = { launch_params.workspace_bytes };
+    workspace->data.dtype = layer_norm::DType::kByte;
+    workspace->data.shape = { launch_params.workspace_bytes };
 
-    barrier->dtype = layer_norm::DType::kInt32;
-    barrier->shape = { launch_params.barrier_size };
+    barrier->data.dtype = layer_norm::DType::kInt32;
+    barrier->data.shape = { launch_params.barrier_size };
 
     return;
   }
 
   if ( launch_params.barrier_size > 0 ) {
-    params.workspace = workspace->dptr;
-    params.barrier = reinterpret_cast<int*>(barrier->dptr);
+    params.workspace = workspace->data.dptr;
+    params.barrier = reinterpret_cast<int*>(barrier->data.dptr);
     cudaMemsetAsync(params.barrier, 0,
-                    layer_norm::product(barrier->shape) *
-                    typeToSize(barrier->dtype), stream);
+                    layer_norm::product(barrier->data.shape) *
+                    typeToSize(barrier->data.dtype), stream);
   }
 
   // Launch the kernel.
@@ -373,7 +367,6 @@ void layernorm_bwd(const Tensor& dz,
 void nvte_layernorm_fwd(const NVTETensor x,      // BxSxhidden_size
                         const NVTETensor gamma,  // hidden_size
                         const NVTETensor beta,   // hidden_size
-                        const NVTETensor scale,  // 1
                         const float epsilon,
                         NVTETensor z,
                         NVTETensor mu,
@@ -381,14 +374,11 @@ void nvte_layernorm_fwd(const NVTETensor x,      // BxSxhidden_size
                         cudaStream_t stream,
                         const int multiprocessorCount,
                         NVTETensor workspace,
-                        NVTETensor barrier,
-                        NVTETensor amax,
-                        NVTETensor scale_inv) {
+                        NVTETensor barrier) {
   using namespace transformer_engine;
   layernorm_fwd(*reinterpret_cast<const Tensor*>(x),
                 *reinterpret_cast<const Tensor*>(gamma),
                 *reinterpret_cast<const Tensor*>(beta),
-                *reinterpret_cast<const Tensor*>(scale),
                 epsilon,
                 reinterpret_cast<Tensor*>(z),
                 reinterpret_cast<Tensor*>(mu),
@@ -396,9 +386,7 @@ void nvte_layernorm_fwd(const NVTETensor x,      // BxSxhidden_size
                 stream,
                 multiprocessorCount,
                 reinterpret_cast<Tensor*>(workspace),
-                reinterpret_cast<Tensor*>(barrier),
-                reinterpret_cast<Tensor*>(amax),
-                reinterpret_cast<Tensor*>(scale_inv));
+                reinterpret_cast<Tensor*>(barrier));
 }
 
 void nvte_layernorm_bwd(const NVTETensor dz,  // BxSxhidden_size
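
The layernorm functions above keep the existing two-call workspace-query protocol: when the workspace/barrier tensors carry null data pointers, the function only fills in their required shapes and dtypes and returns. A sketch of a caller (argument names are illustrative, and the rsigma parameter sits in lines elided from the hunk above):

    // First call: workspace and barrier have null data pointers, so only
    // their required shapes and dtypes are populated.
    nvte_layernorm_fwd(x, gamma, beta, epsilon, z, mu, rsigma,
                       stream, sm_count, workspace, barrier);
    // ... allocate the two buffers to the returned shapes ...
    // Second call: buffers are now non-null, so the kernel actually launches.
    nvte_layernorm_fwd(x, gamma, beta, epsilon, z, mu, rsigma,
                       stream, sm_count, workspace, barrier);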
@@ -15,15 +15,77 @@ size_t typeToSize(const transformer_engine::DType type) {
   );  // NOLINT(*)
 }
 
+bool is_fp8_dtype(const transformer_engine::DType t) {
+  return t == transformer_engine::DType::kFloat8E4M3 ||
+         t == transformer_engine::DType::kFloat8E5M2;
+}
+
+void CheckInputTensor(const Tensor &t, const std::string &name) {
+  const DType type = t.data.dtype;
+  if (is_fp8_dtype(type)) {
+    // FP8 input needs to have scale_inv
+    NVTE_CHECK(t.scale_inv.dptr != nullptr,
+               "FP8 input " + name + " must have inverse of scale.");
+    NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32);
+    NVTE_CHECK(t.scale_inv.shape == std::vector<size_t>{ 1 });
+  } else {
+    NVTE_CHECK(t.scale.dptr == nullptr,
+               "Scale is not supported for non-FP8 input " + name + ".");
+    NVTE_CHECK(t.amax.dptr == nullptr,
+               "Amax is not supported for non-FP8 input " + name + ".");
+    NVTE_CHECK(t.scale_inv.dptr == nullptr,
+               "Scale_inv is not supported for non-FP8 input " + name + ".");
+  }
+  NVTE_CHECK(t.data.dptr != nullptr,
+             "Input " + name + " is not allocated!");
+}
+
+void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty) {
+  const DType type = t.data.dtype;
+  if (is_fp8_dtype(type)) {
+    // FP8 output needs to have scale, amax and scale_inv
+    NVTE_CHECK(t.amax.dptr != nullptr,
+               "FP8 output " + name + " must have amax tensor.");
+    NVTE_CHECK(t.amax.dtype == DType::kFloat32);
+    NVTE_CHECK(t.amax.shape == std::vector<size_t>{ 1 });
+    NVTE_CHECK(t.scale_inv.dptr != nullptr,
+               "FP8 output " + name + " must have inverse of scale.");
+    NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32);
+    NVTE_CHECK(t.scale_inv.shape == std::vector<size_t>{ 1 });
+    NVTE_CHECK(t.scale.dptr != nullptr,
+               "FP8 output " + name + " must have scale.");
+    NVTE_CHECK(t.scale.dtype == DType::kFloat32);
+    NVTE_CHECK(t.scale.shape == std::vector<size_t>{ 1 });
+  } else {
+    NVTE_CHECK(t.scale.dptr == nullptr,
+               "Scale is not supported for non-FP8 output " + name + ".");
+    NVTE_CHECK(t.amax.dptr == nullptr,
+               "Amax is not supported for non-FP8 output " + name + ".");
+    NVTE_CHECK(t.scale_inv.dptr == nullptr,
+               "Scale_inv is not supported for non-FP8 output " + name + ".");
+  }
+  if (!allow_empty) {
+    NVTE_CHECK(t.data.dptr != nullptr,
+               "Output " + name + " is not allocated!");
+  }
+}
 
 }  // namespace transformer_engine
 
 NVTETensor nvte_create_tensor(void *dptr,
                               const NVTEShape shape,
-                              const NVTEDType dtype) {
+                              const NVTEDType dtype,
+                              float *amax,
+                              float *scale,
+                              float *scale_inv) {
   transformer_engine::Tensor *ret = new transformer_engine::Tensor;
-  ret->dptr = dptr;
-  ret->shape = std::vector<size_t>(shape.data, shape.data + shape.ndim);
-  ret->dtype = static_cast<transformer_engine::DType>(dtype);
+  ret->data.dptr = dptr;
+  ret->data.shape = std::vector<size_t>(shape.data, shape.data + shape.ndim);
+  ret->data.dtype = static_cast<transformer_engine::DType>(dtype);
+  ret->amax.dptr = amax;
+  ret->scale.dptr = scale;
+  ret->scale_inv.dptr = scale_inv;
   return ret;
 }
@@ -34,18 +96,40 @@ void nvte_destroy_tensor(NVTETensor tensor) {
 }
 
 NVTEDType nvte_tensor_type(const NVTETensor tensor) {
-  return static_cast<NVTEDType>(reinterpret_cast<const transformer_engine::Tensor*>(tensor)->dtype);
+  return static_cast<NVTEDType>(
+      reinterpret_cast<const transformer_engine::Tensor*>(tensor)->data.dtype);
 }
 
 NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
   const auto &t = *reinterpret_cast<const transformer_engine::Tensor*>(tensor);
   NVTEShape ret;
-  ret.data = t.shape.data();
-  ret.ndim = t.shape.size();
+  ret.data = t.data.shape.data();
+  ret.ndim = t.data.shape.size();
   return ret;
 }
 void *nvte_tensor_data(const NVTETensor tensor) {
   const auto &t = *reinterpret_cast<const transformer_engine::Tensor*>(tensor);
-  return t.dptr;
+  return t.data.dptr;
+}
+
+float *nvte_tensor_amax(const NVTETensor tensor) {
+  const auto &t = *reinterpret_cast<const transformer_engine::Tensor*>(tensor);
+  NVTE_CHECK(t.amax.dtype == transformer_engine::DType::kFloat32,
+             "Tensor's amax must have Float32 type!");
+  return reinterpret_cast<float*>(t.amax.dptr);
+}
+
+float *nvte_tensor_scale(const NVTETensor tensor) {
+  const auto &t = *reinterpret_cast<const transformer_engine::Tensor*>(tensor);
+  NVTE_CHECK(t.scale.dtype == transformer_engine::DType::kFloat32,
+             "Tensor's scale must have Float32 type!");
+  return reinterpret_cast<float*>(t.scale.dptr);
+}
+
+float *nvte_tensor_scale_inv(const NVTETensor tensor) {
+  const auto &t = *reinterpret_cast<const transformer_engine::Tensor*>(tensor);
+  NVTE_CHECK(t.scale_inv.dtype == transformer_engine::DType::kFloat32,
+             "Tensor's inverse of scale must have Float32 type!");
+  return reinterpret_cast<float*>(t.scale_inv.dptr);
 }
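
The three accessors above give C callers a checked way to read the metadata back. A short usage sketch (t is any NVTETensor created as shown earlier; the returned pointers are the fp32 device scalars bound at creation):

    float *amax      = nvte_tensor_amax(t);
    float *scale     = nvte_tensor_scale(t);
    float *scale_inv = nvte_tensor_scale_inv(t);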
@@ -105,7 +105,7 @@ cast_transpose_kernel(const IType * const input,
                                warp_id_in_tile * n_iterations) %
                               THREADS_PER_WARP;
   CType max = 0;
-  const CType scale = *scale_ptr;
+  const CType scale = scale_ptr != nullptr ? *scale_ptr : 1;
 #pragma unroll
   for (unsigned int i = 0; i < nvec_out; ++i) {
     in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
@@ -158,8 +158,8 @@ cast_transpose_kernel(const IType * const input,
 
   if (threadIdx.x == 0) {
     static_assert(std::is_same<CType, float>::value);
-    atomicMaxFloat(amax, max);
-    reciprocal<float>(scale_inv, scale);
+    if (amax != nullptr) atomicMaxFloat(amax, max);
+    if (scale_inv != nullptr) reciprocal<float>(scale_inv, scale);
   }
 }
@@ -222,7 +222,7 @@ cast_transpose_kernel_notaligned(const IType * const input,
                                warp_id_in_tile * n_iterations) %
                               THREADS_PER_WARP;
   CType max = 0;
-  const CType scale = *scale_ptr;
+  const CType scale = scale_ptr != nullptr ? *scale_ptr : 1;
   {
     const bool valid_load = my_place < tile_length &&
                             warp_id_in_tile * n_iterations < tile_height;
@@ -294,48 +294,41 @@ cast_transpose_kernel_notaligned(const IType * const input,
 
   if (threadIdx.x == 0) {
     static_assert(std::is_same<CType, float>::value);
-    atomicMaxFloat(amax, max);
-    reciprocal<float>(scale_inv, scale);
+    if (amax != nullptr) atomicMaxFloat(amax, max);
+    if (scale_inv != nullptr) reciprocal<float>(scale_inv, scale);
   }
 }
 
 void cast_transpose(const Tensor &input,
-                    const Tensor &scale,
                     Tensor *cast_output,
                     Tensor *transposed_output,
-                    Tensor *amax,
-                    Tensor *scale_inv,
                     cudaStream_t stream) {
-  NVTE_CHECK(input.shape.size() == 2, "Input must have 2 dimensions.");
-  NVTE_CHECK(cast_output->shape.size() == 2, "C output must have 2 dimensions.");
-  NVTE_CHECK(transposed_output->shape.size() == 2, "T output must have 2 dimensions.");
-  NVTE_CHECK(input.shape == cast_output->shape, "Input and C output must have the same shape.");
+  CheckInputTensor(input, "cast_transpose_input");
+  CheckOutputTensor(*cast_output, "cast_output");
+  CheckOutputTensor(*transposed_output, "transposed_output");
 
-  const size_t row_length = input.shape[1];
-  const size_t num_rows = input.shape[0];
+  NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
+  NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
+  NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions.");
+  NVTE_CHECK(input.data.shape == cast_output->data.shape,
+             "Input and C output must have the same shape.");
 
-  NVTE_CHECK(transposed_output->shape[0] == row_length, "Wrong dimension of T output.");
-  NVTE_CHECK(transposed_output->shape[1] == num_rows, "Wrong dimension of T output.");
+  const size_t row_length = input.data.shape[1];
+  const size_t num_rows = input.data.shape[0];
 
-  NVTE_CHECK(cast_output->dtype == transposed_output->dtype,
-             "Both C and T outputs need to have the same type.");
+  NVTE_CHECK(transposed_output->data.shape[0] == row_length, "Wrong dimension of T output.");
+  NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output.");
 
-  NVTE_CHECK(amax->shape == std::vector<size_t>{ 1 }, "AMAX tensor must have 1 element.");
-  NVTE_CHECK(amax->dtype == DType::kFloat32, "AMAX tensor must have Float32 type.");
-  NVTE_CHECK(scale_inv->shape == std::vector<size_t>{ 1 },
-             "scale_inv tensor must have 1 element.");
-  NVTE_CHECK(scale_inv->dtype == DType::kFloat32, "scale_inv tensor must have Float32 type.");
-  NVTE_CHECK(scale.shape == std::vector<size_t>{ 1 }, "Scale tensor must have 1 element.");
-  NVTE_CHECK(scale.dtype == DType::kFloat32, "Scale tensor must have Float32 type.");
-
-  NVTE_CHECK(input.dptr != nullptr, "Input is not allocated.");
-  NVTE_CHECK(scale.dptr != nullptr, "Scale is not allocated.");
-  NVTE_CHECK(transposed_output->dptr != nullptr, "T output is not allocated.");
-  NVTE_CHECK(cast_output->dptr != nullptr, "C output is not allocated.");
-  NVTE_CHECK(amax->dptr != nullptr, "AMAX output is not allocated.");
-  NVTE_CHECK(scale_inv->dptr != nullptr, "scale_inv output is not allocated.");
-
-  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.dtype, InputType,
-    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->dtype, OutputType,
+  NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype,
+             "C and T outputs need to have the same type.");
+  NVTE_CHECK(cast_output->amax.dptr == transposed_output->amax.dptr,
+             "C and T outputs need to share amax tensor.");
+  NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr,
+             "C and T outputs need to share scale tensor.");
+  NVTE_CHECK(cast_output->scale_inv.dptr == transposed_output->scale_inv.dptr,
+             "C and T outputs need to share scale inverse tensor.");
+
+  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, InputType,
+    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->data.dtype, OutputType,
       constexpr int itype_size = sizeof(InputType);
       constexpr int otype_size = sizeof(OutputType);
       constexpr int nvec_in = desired_load_size / itype_size;
@@ -363,12 +356,12 @@ void cast_transpose(const Tensor &input,
                             cast_transpose_num_threads / n_warps_per_tile *
                             (THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>),
                             stream>>>(
-            reinterpret_cast<const InputType *>(input.dptr),
-            reinterpret_cast<OutputType *>(cast_output->dptr),
-            reinterpret_cast<OutputType *>(transposed_output->dptr),
-            reinterpret_cast<const fp32 *>(scale.dptr),
-            reinterpret_cast<fp32 *>(amax->dptr),
-            reinterpret_cast<fp32 *>(scale_inv->dptr),
+            reinterpret_cast<const InputType *>(input.data.dptr),
+            reinterpret_cast<OutputType *>(cast_output->data.dptr),
+            reinterpret_cast<OutputType *>(transposed_output->data.dptr),
+            reinterpret_cast<const fp32 *>(cast_output->scale.dptr),
+            reinterpret_cast<fp32 *>(cast_output->amax.dptr),
+            reinterpret_cast<fp32 *>(cast_output->scale_inv.dptr),
             row_length, num_rows, n_tiles);
       } else {
         cudaFuncSetAttribute(cast_transpose_kernel_notaligned<nvec_in, nvec_out, fp32,
@@ -381,12 +374,12 @@ void cast_transpose(const Tensor &input,
                             cast_transpose_num_threads / n_warps_per_tile *
                             (THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>),
                             stream>>>(
-            reinterpret_cast<const InputType *>(input.dptr),
-            reinterpret_cast<OutputType *>(cast_output->dptr),
-            reinterpret_cast<OutputType *>(transposed_output->dptr),
-            reinterpret_cast<const fp32 *>(scale.dptr),
-            reinterpret_cast<fp32 *>(amax->dptr),
-            reinterpret_cast<fp32 *>(scale_inv->dptr),
+            reinterpret_cast<const InputType *>(input.data.dptr),
+            reinterpret_cast<OutputType *>(cast_output->data.dptr),
+            reinterpret_cast<OutputType *>(transposed_output->data.dptr),
+            reinterpret_cast<const fp32 *>(cast_output->scale.dptr),
+            reinterpret_cast<fp32 *>(cast_output->amax.dptr),
+            reinterpret_cast<fp32 *>(cast_output->scale_inv.dptr),
             row_length, num_rows, n_tiles);
       }
   );  // NOLINT(*)
@@ -396,18 +389,12 @@ void cast_transpose(const Tensor &input,
 }  // namespace transformer_engine
 
 void nvte_cast_transpose(const NVTETensor input,
-                         const NVTETensor scale,
                          NVTETensor cast_output,
                          NVTETensor transposed_output,
-                         NVTETensor amax,
-                         NVTETensor scale_inv,
                          cudaStream_t stream) {
   using namespace transformer_engine;
   cast_transpose(*reinterpret_cast<const Tensor*>(input),
-                 *reinterpret_cast<const Tensor*>(scale),
                  reinterpret_cast<Tensor*>(cast_output),
                  reinterpret_cast<Tensor*>(transposed_output),
-                 reinterpret_cast<Tensor*>(amax),
-                 reinterpret_cast<Tensor*>(scale_inv),
                  stream);
 }
@@ -162,7 +162,7 @@ cast_transpose_dbias_kernel(const Param param,
                                warp_id_in_tile * n_iterations) %
                               THREADS_PER_WARP;
   CType max = 0;
-  const CType scale = *param.scale_ptr;
+  const CType scale = param.scale_ptr != nullptr ? *param.scale_ptr : 1;
 
   partial_dbias.clear();
@@ -240,8 +240,8 @@ cast_transpose_dbias_kernel(const Param param,
 
   if (threadIdx.x == 0) {
     static_assert(std::is_same<CType, float>::value);
-    atomicMaxFloat(param.amax, max);
-    reciprocal<CType>(param.scale_inv, scale);
+    if (param.amax != nullptr) atomicMaxFloat(param.amax, max);
+    if (param.scale_inv != nullptr) reciprocal<CType>(param.scale_inv, scale);
   }
 }
@@ -310,7 +310,7 @@ cast_transpose_dbias_kernel_notaligned(const Param param,
                                warp_id_in_tile * n_iterations) %
                               THREADS_PER_WARP;
   CType max = 0;
-  const CType scale = *param.scale_ptr;
+  const CType scale = param.scale_ptr != nullptr ? *param.scale_ptr : 1;
 
   partial_dbias.clear();
@@ -409,8 +409,8 @@ cast_transpose_dbias_kernel_notaligned(const Param param,
 
   if (threadIdx.x == 0) {
     static_assert(std::is_same<CType, float>::value);
-    atomicMaxFloat(param.amax, max);
-    reciprocal<CType>(param.scale_inv, scale);
+    if (param.amax != nullptr) atomicMaxFloat(param.amax, max);
+    if (param.scale_inv != nullptr) reciprocal<CType>(param.scale_inv, scale);
   }
 }
@@ -456,16 +456,16 @@ reduce_dbias_kernel(OutputType* const dbias_output,
 void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output, /*cast*/
                                                     Tensor* workspace,
                                                     const int nvec_out) {
-  const size_t row_length = cast_output.shape[1];
-  const size_t num_rows = cast_output.shape[0];
+  const size_t row_length = cast_output.data.shape[1];
+  const size_t num_rows = cast_output.data.shape[0];
   const size_t tile_size_y = (nvec_out * THREADS_PER_WARP);
 
   NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
   const size_t num_rows_partial_dbias = DIVUP(num_rows, tile_size_y);
 
-  workspace->shape = {num_rows_partial_dbias, row_length};
-  workspace->dtype = DType::kFloat32;
+  workspace->data.shape = {num_rows_partial_dbias, row_length};
+  workspace->data.dtype = DType::kFloat32;
 }
 
 template <typename InputType>
@@ -489,61 +489,54 @@ void reduce_dbias(const Tensor &workspace, Tensor *dbias,
                       reduce_dbias_num_threads,
                       0,
                       stream>>>(
-      reinterpret_cast<InputType *>(dbias->dptr),
-      reinterpret_cast<const fp32 *>(workspace.dptr),
+      reinterpret_cast<InputType *>(dbias->data.dptr),
+      reinterpret_cast<const fp32 *>(workspace.data.dptr),
       reduce_dbias_row_length,
       reduce_dbias_num_rows);
 }
 
 void cast_transpose_dbias(const Tensor &input,
-                          const Tensor &scale,
                           Tensor *cast_output,
                           Tensor *transposed_output,
-                          Tensor *amax,
                           Tensor *dbias,
-                          Tensor *scale_inv,
                           Tensor *workspace,
                           cudaStream_t stream) {
-  NVTE_CHECK(input.shape.size() == 2, "Input must have 2 dimensions.");
-  NVTE_CHECK(cast_output->shape.size() == 2, "C output must have 2 dimensions.");
-  NVTE_CHECK(transposed_output->shape.size() == 2, "T output must have 2 dimensions.");
-  NVTE_CHECK(input.shape == cast_output->shape, "Input and C output must have the same shape.");
-  const size_t row_length = input.shape[1];
-  const size_t num_rows = input.shape[0];
+  CheckInputTensor(input, "cast_transpose_dbias_input");
+  CheckOutputTensor(*cast_output, "cast_output");
+  CheckOutputTensor(*transposed_output, "transposed_output");
+  CheckOutputTensor(*dbias, "dbias");
 
-  NVTE_CHECK(transposed_output->shape[0] == row_length, "Wrong dimension of T output.");
-  NVTE_CHECK(transposed_output->shape[1] == num_rows, "Wrong dimension of T output.");
+  NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
+  NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
+  NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions.");
+  NVTE_CHECK(input.data.shape == cast_output->data.shape,
+             "Input and C output must have the same shape.");
+  const size_t row_length = input.data.shape[1];
+  const size_t num_rows = input.data.shape[0];
 
-  NVTE_CHECK(cast_output->dtype == transposed_output->dtype,
-             "Both T and C outputs need to have the same type.");
+  NVTE_CHECK(transposed_output->data.shape[0] == row_length, "Wrong dimension of T output.");
+  NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output.");
 
-  NVTE_CHECK(amax->shape == std::vector<size_t>{ 1 }, "AMAX tensor must have 1 element.");
-  NVTE_CHECK(amax->dtype == DType::kFloat32, "AMAX tensor must have Float32 type.");
-  NVTE_CHECK(scale_inv->shape == std::vector<size_t>{ 1 },
-             "scale_inv tensor must have 1 element.");
-  NVTE_CHECK(scale_inv->dtype == DType::kFloat32, "scale_inv tensor must have Float32 type.");
-  NVTE_CHECK(scale.shape == std::vector<size_t>{ 1 }, "Scale tensor must have 1 element.");
-  NVTE_CHECK(scale.dtype == DType::kFloat32, "Scale tensor must have Float32 type.");
-
-  NVTE_CHECK(input.dptr != nullptr, "Input is not allocated.");
-  NVTE_CHECK(scale.dptr != nullptr, "Scale is not allocated.");
-  NVTE_CHECK(transposed_output->dptr != nullptr, "T output is not allocated.");
-  NVTE_CHECK(cast_output->dptr != nullptr, "C output is not allocated.");
-  NVTE_CHECK(amax->dptr != nullptr, "AMAX output is not allocated.");
-  NVTE_CHECK(scale_inv->dptr != nullptr, "scale_inv output is not allocated.");
-
-  NVTE_CHECK(dbias->dptr != nullptr, "DBias is not allocated.");
-  NVTE_CHECK(dbias->dtype == input.dtype, "DBias must have the same type as input.");
-  NVTE_CHECK(dbias->shape == std::vector<size_t>{ row_length }, "Wrong shape of DBias.");
-
-  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.dtype, InputType,
-    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->dtype, OutputType,
+  NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype,
+             "C and T outputs need to have the same type.");
+  NVTE_CHECK(cast_output->amax.dptr == transposed_output->amax.dptr,
+             "C and T outputs need to share amax tensor.");
+  NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr,
+             "C and T outputs need to share scale tensor.");
+  NVTE_CHECK(cast_output->scale_inv.dptr == transposed_output->scale_inv.dptr,
+             "C and T outputs need to share scale inverse tensor.");
+
+  NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input.");
+  NVTE_CHECK(dbias->data.shape == std::vector<size_t>{ row_length }, "Wrong shape of DBias.");
+
+  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, InputType,
+    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->data.dtype, OutputType,
       constexpr int itype_size = sizeof(InputType);
       constexpr int otype_size = sizeof(OutputType);
       constexpr int nvec_in = desired_load_size / itype_size;
      constexpr int nvec_out = desired_store_size / otype_size;
 
-      if (workspace->dptr == nullptr) {
+      if (workspace->data.dptr == nullptr) {
         populate_cast_transpose_dbias_workspace_config(*cast_output, workspace, nvec_out);
         return;
       }
@@ -567,13 +560,13 @@ void cast_transpose_dbias(const Tensor &input,
       static_assert(shared_size_transpose >= shared_size_dbias);
       using Param = CTDBiasParam<InputType, OutputType, ComputeType>;
       Param param;
-      param.input = reinterpret_cast<const InputType *>(input.dptr);
-      param.output_c = reinterpret_cast<OutputType *>(cast_output->dptr);
-      param.output_t = reinterpret_cast<OutputType *>(transposed_output->dptr);
-      param.scale_ptr = reinterpret_cast<const ComputeType *>(scale.dptr);
-      param.amax = reinterpret_cast<ComputeType *>(amax->dptr);
-      param.scale_inv = reinterpret_cast<ComputeType *>(scale_inv->dptr);
-      param.workspace = reinterpret_cast<ComputeType *>(workspace->dptr);
+      param.input = reinterpret_cast<const InputType *>(input.data.dptr);
+      param.output_c = reinterpret_cast<OutputType *>(cast_output->data.dptr);
+      param.output_t = reinterpret_cast<OutputType *>(transposed_output->data.dptr);
+      param.scale_ptr = reinterpret_cast<const ComputeType *>(cast_output->scale.dptr);
+      param.amax = reinterpret_cast<ComputeType *>(cast_output->amax.dptr);
+      param.scale_inv = reinterpret_cast<ComputeType *>(cast_output->scale_inv.dptr);
+      param.workspace = reinterpret_cast<ComputeType *>(workspace->data.dptr);
 
       if (full_tile) {
         cudaFuncSetAttribute(cast_transpose_dbias_kernel<nvec_in, nvec_out, Param>,
@@ -678,7 +671,7 @@ cast_transpose_dbias_dgelu_kernel(const Param param,
                                warp_id_in_tile * n_iterations) %
                               THREADS_PER_WARP;
   CType max = 0;
-  const CType scale = *param.scale_ptr;
+  const CType scale = param.scale_ptr != nullptr ? *param.scale_ptr : 1;
 
   partial_dbias.clear();
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value); static_assert(std::is_same<CType, float>::value);
atomicMaxFloat(param.amax, max); if (param.amax != nullptr) atomicMaxFloat(param.amax, max);
reciprocal<CType>(param.scale_inv, scale); if (param.scale_inv != nullptr) reciprocal<CType>(param.scale_inv, scale);
} }
} }
@@ -846,7 +839,7 @@ cast_transpose_dbias_dgelu_kernel_notaligned(const Param param,
                                warp_id_in_tile * n_iterations) %
                               THREADS_PER_WARP;
   CType max = 0;
-  const CType scale = *param.scale_ptr;
+  const CType scale = param.scale_ptr != nullptr ? *param.scale_ptr : 1;
 
   partial_dbias.clear();
@@ -960,61 +953,53 @@ cast_transpose_dbias_dgelu_kernel_notaligned(const Param param,
 
   if (threadIdx.x == 0) {
     static_assert(std::is_same<CType, float>::value);
-    atomicMaxFloat(param.amax, max);
-    reciprocal<CType>(param.scale_inv, scale);
+    if (param.amax != nullptr) atomicMaxFloat(param.amax, max);
+    if (param.scale_inv != nullptr) reciprocal<CType>(param.scale_inv, scale);
   }
 }
 
 void cast_transpose_dbias_dgelu(const Tensor &input,
                                 const Tensor &gelu_input,
-                                const Tensor &scale,
                                 Tensor *cast_output,
                                 Tensor *transposed_output,
-                                Tensor *amax,
                                 Tensor *dbias,
-                                Tensor *scale_inv,
                                 Tensor *workspace,
                                 cudaStream_t stream) {
-  NVTE_CHECK(input.shape.size() == 2, "Input must have 2 dimensions.");
-  NVTE_CHECK(cast_output->shape.size() == 2, "C output must have 2 dimensions.");
-  NVTE_CHECK(transposed_output->shape.size() == 2,
+  CheckInputTensor(input, "cast_transpose_dbias_dgelu_input");
+  CheckInputTensor(gelu_input, "gelu_input");
+  CheckOutputTensor(*cast_output, "cast_output");
+  CheckOutputTensor(*transposed_output, "transposed_output");
+  CheckOutputTensor(*dbias, "dbias");
+
+  NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
+  NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
+  NVTE_CHECK(transposed_output->data.shape.size() == 2,
              "T output must have 2 dimensions.");
-  NVTE_CHECK(input.shape == cast_output->shape,
+  NVTE_CHECK(input.data.shape == cast_output->data.shape,
              "Input and C output must have the same shape.");
-  const size_t row_length = input.shape[1];
-  const size_t num_rows = input.shape[0];
+  const size_t row_length = input.data.shape[1];
+  const size_t num_rows = input.data.shape[0];
 
-  NVTE_CHECK(transposed_output->shape[0] == row_length, "Wrong dimension of T output.");
-  NVTE_CHECK(transposed_output->shape[1] == num_rows, "Wrong dimension of T output.");
+  NVTE_CHECK(transposed_output->data.shape[0] == row_length, "Wrong dimension of T output.");
+  NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output.");
 
-  NVTE_CHECK(cast_output->dtype == transposed_output->dtype,
-             "Both C and T outputs need to have the same type.");
-
-  NVTE_CHECK(amax->shape == std::vector<size_t>{ 1 }, "AMAX tensor must have 1 element.");
-  NVTE_CHECK(amax->dtype == DType::kFloat32, "AMAX tensor must have Float32 type.");
-  NVTE_CHECK(scale_inv->shape == std::vector<size_t>{ 1 },
-             "scale_inv tensor must have 1 element.");
-  NVTE_CHECK(scale_inv->dtype == DType::kFloat32, "scale_inv tensor must have Float32 type.");
-  NVTE_CHECK(scale.shape == std::vector<size_t>{ 1 }, "Scale tensor must have 1 element.");
-  NVTE_CHECK(scale.dtype == DType::kFloat32, "Scale tensor must have Float32 type.");
-
-  NVTE_CHECK(input.dptr != nullptr, "Input is not allocated.");
-  NVTE_CHECK(gelu_input.dptr != nullptr, "GeLU input is not allocated.");
-  NVTE_CHECK(scale.dptr != nullptr, "Scale is not allocated.");
-  NVTE_CHECK(transposed_output->dptr != nullptr, "T output is not allocated.");
-  NVTE_CHECK(cast_output->dptr != nullptr, "C output is not allocated.");
-  NVTE_CHECK(amax->dptr != nullptr, "AMAX output is not allocated.");
-  NVTE_CHECK(scale_inv->dptr != nullptr, "scale_inv output is not allocated.");
-
-  NVTE_CHECK(dbias->dptr != nullptr, "DBias is not allocated.");
-  NVTE_CHECK(dbias->dtype == input.dtype, "DBias must have the same type as input.");
-  NVTE_CHECK(dbias->shape == std::vector<size_t>{ row_length }, "Wrong shape of DBias.");
-  NVTE_CHECK(input.dtype == gelu_input.dtype, "Types of both inputs must match.");
-  NVTE_CHECK(input.shape == gelu_input.shape, "Shapes of both inputs must match.");
-
-  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.dtype, InputType,
-    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->dtype, OutputType,
+  NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype,
+             "C and T outputs need to have the same type.");
+  NVTE_CHECK(cast_output->amax.dptr == transposed_output->amax.dptr,
+             "C and T outputs need to share amax tensor.");
+  NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr,
+             "C and T outputs need to share scale tensor.");
+  NVTE_CHECK(cast_output->scale_inv.dptr == transposed_output->scale_inv.dptr,
+             "C and T outputs need to share scale inverse tensor.");
+
+  NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input.");
+  NVTE_CHECK(dbias->data.shape == std::vector<size_t>{ row_length }, "Wrong shape of DBias.");
+
+  NVTE_CHECK(input.data.dtype == gelu_input.data.dtype, "Types of both inputs must match.");
+  NVTE_CHECK(input.data.shape == gelu_input.data.shape, "Shapes of both inputs must match.");
+
+  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, InputType,
+    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->data.dtype, OutputType,
       using InputType2 = InputType;
       /* dgelu fusion kernel uses more registers */
       constexpr int desired_load_size_dgelu = 4;
@@ -1024,7 +1009,7 @@ void cast_transpose_dbias_dgelu(const Tensor &input,
      constexpr int nvec_in = desired_load_size_dgelu / itype_size;
      constexpr int nvec_out = desired_store_size_dgelu / otype_size;
 
-      if (workspace->dptr == nullptr) {
+      if (workspace->data.dptr == nullptr) {
        populate_cast_transpose_dbias_workspace_config(*cast_output, workspace, nvec_out);
        return;
      }
@@ -1048,14 +1033,14 @@ void cast_transpose_dbias_dgelu(const Tensor &input,
       static_assert(shared_size_transpose >= shared_size_dbias);
       using Param = CTDBiasDGeluParam<InputType, InputType2, OutputType, ComputeType>;
       Param param;
-      param.input = reinterpret_cast<const InputType *>(input.dptr);
-      param.gelu_input = reinterpret_cast<const InputType2 *>(gelu_input.dptr);
-      param.output_c = reinterpret_cast<OutputType *>(cast_output->dptr);
-      param.output_t = reinterpret_cast<OutputType *>(transposed_output->dptr);
-      param.scale_ptr = reinterpret_cast<const ComputeType *>(scale.dptr);
-      param.amax = reinterpret_cast<ComputeType *>(amax->dptr);
-      param.scale_inv = reinterpret_cast<ComputeType *>(scale_inv->dptr);
-      param.workspace = reinterpret_cast<ComputeType *>(workspace->dptr);
+      param.input = reinterpret_cast<const InputType *>(input.data.dptr);
+      param.gelu_input = reinterpret_cast<const InputType2 *>(gelu_input.data.dptr);
+      param.output_c = reinterpret_cast<OutputType *>(cast_output->data.dptr);
+      param.output_t = reinterpret_cast<OutputType *>(transposed_output->data.dptr);
+      param.scale_ptr = reinterpret_cast<const ComputeType *>(cast_output->scale.dptr);
+      param.amax = reinterpret_cast<ComputeType *>(cast_output->amax.dptr);
+      param.scale_inv = reinterpret_cast<ComputeType *>(cast_output->scale_inv.dptr);
+      param.workspace = reinterpret_cast<ComputeType *>(workspace->data.dptr);
 
       if (full_tile) {
         cudaFuncSetAttribute(cast_transpose_dbias_dgelu_kernel<nvec_in, nvec_out, Param>,
                              cudaFuncAttributePreferredSharedMemoryCarveout,
...@@ -1084,45 +1069,33 @@ void cast_transpose_dbias_dgelu(const Tensor &input, ...@@ -1084,45 +1069,33 @@ void cast_transpose_dbias_dgelu(const Tensor &input,
} // namespace transformer_engine } // namespace transformer_engine
void nvte_cast_transpose_dbias(const NVTETensor input, void nvte_cast_transpose_dbias(const NVTETensor input,
const NVTETensor scale,
NVTETensor cast_output, NVTETensor cast_output,
NVTETensor transposed_output, NVTETensor transposed_output,
NVTETensor amax,
NVTETensor dbias, NVTETensor dbias,
NVTETensor scale_inv,
NVTETensor workspace, NVTETensor workspace,
cudaStream_t stream) { cudaStream_t stream) {
using namespace transformer_engine; using namespace transformer_engine;
cast_transpose_dbias(*reinterpret_cast<const Tensor*>(input), cast_transpose_dbias(*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(scale),
reinterpret_cast<Tensor*>(cast_output), reinterpret_cast<Tensor*>(cast_output),
reinterpret_cast<Tensor*>(transposed_output), reinterpret_cast<Tensor*>(transposed_output),
reinterpret_cast<Tensor*>(amax),
reinterpret_cast<Tensor*>(dbias), reinterpret_cast<Tensor*>(dbias),
reinterpret_cast<Tensor*>(scale_inv),
reinterpret_cast<Tensor*>(workspace), reinterpret_cast<Tensor*>(workspace),
stream); stream);
} }
void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, void nvte_cast_transpose_dbias_dgelu(const NVTETensor input,
const NVTETensor gelu_input, const NVTETensor gelu_input,
const NVTETensor scale,
NVTETensor cast_output, NVTETensor cast_output,
NVTETensor transposed_output, NVTETensor transposed_output,
NVTETensor amax,
NVTETensor dbias, NVTETensor dbias,
NVTETensor scale_inv,
NVTETensor workspace, NVTETensor workspace,
cudaStream_t stream) { cudaStream_t stream) {
using namespace transformer_engine; using namespace transformer_engine;
cast_transpose_dbias_dgelu(*reinterpret_cast<const Tensor*>(input), cast_transpose_dbias_dgelu(*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(gelu_input), *reinterpret_cast<const Tensor*>(gelu_input),
*reinterpret_cast<const Tensor*>(scale),
reinterpret_cast<Tensor*>(cast_output), reinterpret_cast<Tensor*>(cast_output),
reinterpret_cast<Tensor*>(transposed_output), reinterpret_cast<Tensor*>(transposed_output),
reinterpret_cast<Tensor*>(amax),
reinterpret_cast<Tensor*>(dbias), reinterpret_cast<Tensor*>(dbias),
reinterpret_cast<Tensor*>(scale_inv),
reinterpret_cast<Tensor*>(workspace), reinterpret_cast<Tensor*>(workspace),
stream); stream);
} }
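Not part of this commit, but for reviewers: a minimal sketch of how a caller drives the slimmed-down fused C API after this change. The `TensorWrapper` pointer-based constructor and the `nvte_cast_transpose_dbias` signature are taken from this diff; the function name, header paths, dtypes, and buffer allocation are assumptions for illustration only.

```cpp
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>  // header path assumed
#include <transformer_engine/transpose.h>           // header path assumed

// Hypothetical helper: buffers are assumed to be device-allocated elsewhere.
void fused_cast_transpose_dbias_example(void* in, void* out_c, void* out_t, void* dbias,
                                        float* amax, float* scale, float* scale_inv,
                                        size_t N, size_t H, cudaStream_t stream) {
  using transformer_engine::TensorWrapper;
  using transformer_engine::DType;
  // Non-FP8 tensors carry no scaling state (null amax/scale/scale_inv).
  TensorWrapper input(in, {N, H}, DType::kBFloat16, nullptr, nullptr, nullptr);
  // Both FP8 outputs must share one amax/scale/scale_inv triple, per the new checks.
  TensorWrapper cast_output(out_c, {N, H}, DType::kFloat8E4M3, amax, scale, scale_inv);
  TensorWrapper transposed_output(out_t, {H, N}, DType::kFloat8E4M3, amax, scale, scale_inv);
  TensorWrapper dbias_out(dbias, {H}, DType::kBFloat16, nullptr, nullptr, nullptr);
  TensorWrapper workspace;  // empty: the first call only queries the required size

  nvte_cast_transpose_dbias(input.data(), cast_output.data(), transposed_output.data(),
                            dbias_out.data(), workspace.data(), stream);
  // ... allocate a buffer of workspace.shape()/workspace.dtype(), rebuild `workspace`
  // around it, then call nvte_cast_transpose_dbias again to run the actual kernel.
}
```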
@@ -87,7 +87,8 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
   const IType* input = reinterpret_cast<const IType*>(args.input_list[tensor_id]);
   OType* output_c = reinterpret_cast<OType*>(args.output_c_list[tensor_id]);
   OType* output_t = reinterpret_cast<OType*>(args.output_t_list[tensor_id]);
-  const CType scale = *reinterpret_cast<CType*>(args.scale_list[tensor_id]);
+  const CType* scale_ptr = reinterpret_cast<CType*>(args.scale_list[tensor_id]);
+  const CType scale = scale_ptr == nullptr ? 1 : *scale_ptr;
   CType* amax = reinterpret_cast<CType*>(args.amax_list[tensor_id]);
   CType* scale_inv = reinterpret_cast<CType*>(args.scale_inv_list[tensor_id]);
   const int num_rows = args.num_rows_list[tensor_id];
@@ -190,79 +191,56 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
   local_amax = reduce_max<n_warps_per_tile>(local_amax, tidy);
   if (tid == 0) {
     static_assert(std::is_same<CType, float>::value);
-    atomicMaxFloat(amax, local_amax);
+    if (amax != nullptr) atomicMaxFloat(amax, local_amax);
   }
   if (tid == 0 && tile_id == 0) {
-    reciprocal<float>(scale_inv, scale);
+    if (scale_inv != nullptr) reciprocal<float>(scale_inv, scale);
   }
 }

 }  // namespace

 void multi_cast_transpose(const std::vector<Tensor*> input_list,
-                          const std::vector<Tensor*> scale_list,
                           std::vector<Tensor*> cast_output_list,
                           std::vector<Tensor*> transposed_output_list,
-                          std::vector<Tensor*> amax_list,
-                          std::vector<Tensor*> scale_inv_list,
                           cudaStream_t stream) {
   // Check that number of tensors is valid
-  NVTE_CHECK(scale_list.size() == input_list.size(),
-             "Number of input and scale tensors must match");
   NVTE_CHECK(cast_output_list.size() == input_list.size(),
              "Number of input and C output tensors must match");
   NVTE_CHECK(transposed_output_list.size() == input_list.size(),
              "Number of input and T output tensors must match");
-  NVTE_CHECK(amax_list.size() == input_list.size(),
-             "Number of input and AMAX tensors must match");
-  NVTE_CHECK(scale_inv_list.size() == input_list.size(),
-             "Number of input and scale_inv tensors must match");
   if (input_list.empty()) {
     return;
   }

   // Check that tensor properties are valid
-  DType ctype = DType::kFloat32;
-  DType itype = input_list[0]->dtype;
-  DType otype = cast_output_list[0]->dtype;
+  DType itype = input_list[0]->data.dtype;
+  DType otype = cast_output_list[0]->data.dtype;
   for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
     const auto& input = *input_list[tensor_id];
-    const auto& scale = *scale_list[tensor_id];
     const auto& cast_output = *cast_output_list[tensor_id];
     const auto& transposed_output = *transposed_output_list[tensor_id];
-    const auto& amax = *amax_list[tensor_id];
-    const auto& scale_inv = *scale_inv_list[tensor_id];
+    CheckInputTensor(input, "multi_cast_transpose_input_" + std::to_string(tensor_id));
+    CheckInputTensor(cast_output, "multi_cast_output_" + std::to_string(tensor_id));
+    CheckInputTensor(transposed_output, "multi_transpose_output_" + std::to_string(tensor_id));

-    NVTE_CHECK(input.dtype == itype,
+    NVTE_CHECK(input.data.dtype == itype,
                "Input tensor types do not match.");
-    NVTE_CHECK(scale.dtype == ctype,
-               "Scale tensor must have Float32 type.");
-    NVTE_CHECK(cast_output.dtype == otype,
+    NVTE_CHECK(cast_output.data.dtype == otype,
                "C output tensor types do not match.");
-    NVTE_CHECK(transposed_output.dtype == otype,
+    NVTE_CHECK(transposed_output.data.dtype == otype,
                "T output tensor types do not match.");
-    NVTE_CHECK(amax.dtype == ctype,
-               "AMAX tensor must have Float32 type.");
-    NVTE_CHECK(scale_inv.dtype == ctype,
-               "scale_inv tensor must have Float32 type.");

-    NVTE_CHECK(input.shape.size() == 2,
+    NVTE_CHECK(input.data.shape.size() == 2,
                "Input tensor must have 2 dimensions.");
-    NVTE_CHECK(cast_output.shape == input.shape,
+    NVTE_CHECK(cast_output.data.shape == input.data.shape,
                "C output tensor shape does not match input tensor.");
-    NVTE_CHECK(transposed_output.shape.size() == 2,
+    NVTE_CHECK(transposed_output.data.shape.size() == 2,
                "T output tensor shape does not match input tensor.");
-    NVTE_CHECK(transposed_output.shape[0] == input.shape[1],
+    NVTE_CHECK(transposed_output.data.shape[0] == input.data.shape[1],
                "T output tensor shape does not match input tensor.");
-    NVTE_CHECK(transposed_output.shape[1] == input.shape[0],
+    NVTE_CHECK(transposed_output.data.shape[1] == input.data.shape[0],
                "T output tensor shape does not match input tensor.");
-
-    NVTE_CHECK(input.dptr != nullptr, "Input is not allocated.");
-    NVTE_CHECK(scale.dptr != nullptr, "Scale is not allocated.");
-    NVTE_CHECK(cast_output.dptr != nullptr, "C output is not allocated.");
-    NVTE_CHECK(transposed_output.dptr != nullptr, "T output is not allocated.");
-    NVTE_CHECK(amax.dptr != nullptr, "AMAX output is not allocated.");
-    NVTE_CHECK(scale_inv.dptr != nullptr, "scale_inv output is not allocated.");
   }

   // Input matrices are divided into tiles
@@ -304,8 +282,8 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list,
     }

     // Calculate number of thread blocks needed for tensor
-    const int num_rows = input_list[tensor_id]->shape[0];
-    const int row_length = input_list[tensor_id]->shape[1];
+    const int num_rows = input_list[tensor_id]->data.shape[0];
+    const int row_length = input_list[tensor_id]->data.shape[1];
     const int num_tiles_m = (num_rows + tile_dim_m - 1) / tile_dim_m;
     const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n;
     const int num_tiles = num_tiles_m * num_tiles_n;
@@ -317,12 +295,12 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list,
     // Add tensor to kernel argument struct
     const int pos = kernel_args.num_tensors;
-    kernel_args.input_list[pos] = const_cast<void*>(input_list[tensor_id]->dptr);
-    kernel_args.output_c_list[pos] = cast_output_list[tensor_id]->dptr;
-    kernel_args.output_t_list[pos] = transposed_output_list[tensor_id]->dptr;
-    kernel_args.scale_list[pos] = const_cast<void*>(scale_list[tensor_id]->dptr);
-    kernel_args.amax_list[pos] = amax_list[tensor_id]->dptr;
-    kernel_args.scale_inv_list[pos] = scale_inv_list[tensor_id]->dptr;
+    kernel_args.input_list[pos] = const_cast<void*>(input_list[tensor_id]->data.dptr);
+    kernel_args.output_c_list[pos] = cast_output_list[tensor_id]->data.dptr;
+    kernel_args.output_t_list[pos] = transposed_output_list[tensor_id]->data.dptr;
+    kernel_args.scale_list[pos] = cast_output_list[tensor_id]->scale.dptr;
+    kernel_args.amax_list[pos] = cast_output_list[tensor_id]->amax.dptr;
+    kernel_args.scale_inv_list[pos] = cast_output_list[tensor_id]->scale_inv.dptr;
     kernel_args.num_rows_list[pos] = num_rows;
     kernel_args.row_length_list[pos] = row_length;
     kernel_args.block_range[pos+1] = kernel_args.block_range[pos] + num_tiles;
@@ -358,28 +336,19 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list,

 void nvte_multi_cast_transpose(size_t num_tensors,
                                const NVTETensor* input_list,
-                               const NVTETensor* scale_list,
                                NVTETensor* cast_output_list,
                                NVTETensor* transposed_output_list,
-                               NVTETensor* amax_list,
-                               NVTETensor* scale_inv_list,
                                cudaStream_t stream) {
   using namespace transformer_engine;
-  std::vector<Tensor*> input_list_, scale_list_,
-                       cast_output_list_, transposed_output_list_, amax_list_, scale_inv_list_;
+  std::vector<Tensor*> input_list_,
+                       cast_output_list_, transposed_output_list_;
   for (size_t i = 0; i < num_tensors; ++i) {
     input_list_.push_back(reinterpret_cast<Tensor*>(const_cast<NVTETensor&>(input_list[i])));
-    scale_list_.push_back(reinterpret_cast<Tensor*>(const_cast<NVTETensor&>(scale_list[i])));
     cast_output_list_.push_back(reinterpret_cast<Tensor*>(cast_output_list[i]));
     transposed_output_list_.push_back(reinterpret_cast<Tensor*>(transposed_output_list[i]));
-    amax_list_.push_back(reinterpret_cast<Tensor*>(amax_list[i]));
-    scale_inv_list_.push_back(reinterpret_cast<Tensor*>(scale_inv_list[i]));
   }
   multi_cast_transpose(input_list_,
-                       scale_list_,
                        cast_output_list_,
                        transposed_output_list_,
-                       amax_list_,
-                       scale_inv_list_,
                        stream);
 }
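Again not part of the commit: a sketch of the multi-tensor call site after the change. Since each FP8 output now carries its own amax/scale/scale_inv (and the kernel tolerates null scaling pointers, defaulting scale to 1), the API shrinks to input and output lists. The wrapper setup and function name below are assumptions; only `nvte_multi_cast_transpose`'s new signature comes from this diff.

```cpp
#include <vector>
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>  // header path assumed
#include <transformer_engine/transpose.h>           // header path assumed

// Hypothetical helper: wrappers are assumed to be built as in the sketch above,
// with scaling buffers attached to the FP8 outputs only.
void multi_cast_transpose_example(std::vector<transformer_engine::TensorWrapper>& inputs,
                                  std::vector<transformer_engine::TensorWrapper>& cast_outputs,
                                  std::vector<transformer_engine::TensorWrapper>& t_outputs,
                                  cudaStream_t stream) {
  std::vector<NVTETensor> in, out_c, out_t;
  for (size_t i = 0; i < inputs.size(); ++i) {
    in.push_back(inputs[i].data());           // plain high-precision input
    out_c.push_back(cast_outputs[i].data());  // FP8 output carrying amax/scale/scale_inv
    out_t.push_back(t_outputs[i].data());     // shares the same scaling buffers
  }
  nvte_multi_cast_transpose(in.size(), in.data(), out_c.data(), out_t.data(), stream);
}
```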
@@ -244,19 +244,20 @@ transpose_kernel_notaligned(const IType * const input,
 void transpose(const Tensor &input,
                Tensor *transposed_output,
                cudaStream_t stream) {
-  NVTE_CHECK(input.shape.size() == 2, "Input must have 2 dimensions.");
-  NVTE_CHECK(transposed_output->shape.size() == 2, "Output must have 2 dimensions.");
-  const size_t row_length = input.shape[1];
-  const size_t num_rows = input.shape[0];
-  NVTE_CHECK(transposed_output->shape[0] == row_length, "Wrong dimension of output.");
-  NVTE_CHECK(transposed_output->shape[1] == num_rows, "Wrong dimension of output.");
+  NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
+  NVTE_CHECK(transposed_output->data.shape.size() == 2, "Output must have 2 dimensions.");
+  const size_t row_length = input.data.shape[1];
+  const size_t num_rows = input.data.shape[0];
+  NVTE_CHECK(transposed_output->data.shape[0] == row_length, "Wrong dimension of output.");
+  NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of output.");

-  NVTE_CHECK(input.dptr != nullptr, "Input is not allocated.");
-  NVTE_CHECK(transposed_output->dptr != nullptr, "Output is not allocated.");
-  NVTE_CHECK(input.dtype == transposed_output->dtype, "Input and output type must match.");
+  NVTE_CHECK(input.data.dptr != nullptr, "Input is not allocated.");
+  NVTE_CHECK(transposed_output->data.dptr != nullptr, "Output is not allocated.");
+  NVTE_CHECK(input.data.dtype == transposed_output->data.dtype,
+             "Input and output type must match.");

-  TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(input.dtype, Type,
+  TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(input.data.dtype, Type,
     constexpr int type_size = sizeof(Type);
     constexpr int nvec_in = desired_load_size / type_size;
     constexpr int nvec_out = desired_store_size / type_size;
@@ -282,8 +283,8 @@ void transpose(const Tensor &input,
         cast_transpose_num_threads / n_warps_per_tile *
         (THREADS_PER_WARP + 1) * sizeof(Vec<Type, nvec_out>),
         stream>>>(
-          reinterpret_cast<const Type *>(input.dptr),
-          reinterpret_cast<Type *>(transposed_output->dptr),
+          reinterpret_cast<const Type *>(input.data.dptr),
+          reinterpret_cast<Type *>(transposed_output->data.dptr),
           row_length, num_rows, n_tiles);
     } else {
       cudaFuncSetAttribute(transpose_kernel_notaligned<nvec_in, nvec_out, fp32, Type, Type>,
@@ -295,8 +296,8 @@ void transpose(const Tensor &input,
         cast_transpose_num_threads / n_warps_per_tile *
         (THREADS_PER_WARP + 1) * sizeof(Vec<Type, nvec_out>),
         stream>>>(
-          reinterpret_cast<const Type *>(input.dptr),
-          reinterpret_cast<Type *>(transposed_output->dptr),
+          reinterpret_cast<const Type *>(input.data.dptr),
+          reinterpret_cast<Type *>(transposed_output->data.dptr),
           row_length, num_rows, n_tiles);
     }
   );  // NOLINT(*)
...
@@ -30,113 +30,82 @@ __device__ inline fp32 dequantize_func(fp32 value, const DequantizeParam &param)
 }  // namespace detail

 void fp8_quantize(const Tensor &input,
-                  const Tensor &scale,
                   Tensor *output,
-                  Tensor *amax,
-                  Tensor *scale_inv,
                   cudaStream_t stream) {
+  CheckInputTensor(input, "cast_input");
+  CheckOutputTensor(*output, "cast_output");
+
-  NVTE_CHECK(input.dtype != DType::kFloat8E4M3 &&
-             input.dtype != DType::kFloat8E5M2,
-             "Input must be in higher precision.");
-  NVTE_CHECK(input.dptr != nullptr, "Input is not allocated.");
-
-  NVTE_CHECK(output->dptr != nullptr, "Output is not allocated.");
-  NVTE_CHECK(output->dtype == DType::kFloat8E4M3 ||
-             output->dtype == DType::kFloat8E5M2,
-             "Output must have FP8 type.");
-  NVTE_CHECK(output->shape == input.shape, "Input and output shapes need to match.");
-
-  NVTE_CHECK(scale.dptr != nullptr, "Scale is not allocated.");
-  NVTE_CHECK(scale.dtype == DType::kFloat32, "Scale must have FP32 type.");
-  NVTE_CHECK(scale.shape == std::vector<size_t>{ 1 }, "Scale must have 1 element.");
-
-  NVTE_CHECK(amax->dptr != nullptr, "AMAX is not allocated.");
-  NVTE_CHECK(amax->dtype == DType::kFloat32, "AMAX must have FP32 type.");
-  NVTE_CHECK(amax->shape == std::vector<size_t>{ 1 }, "AMAX must have 1 element.");
-
-  NVTE_CHECK(scale_inv->dptr != nullptr, "Inverted scale is not allocated.");
-  NVTE_CHECK(scale_inv->dtype == DType::kFloat32, "Inverted scale must have FP32 type.");
-  NVTE_CHECK(scale_inv->shape == std::vector<size_t>{ 1 }, "Inverted scale must have 1 element.");
-
-  const size_t N = product(input.shape);
-  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.dtype, IType,
-    TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output->dtype, OType,
+  NVTE_CHECK(!is_fp8_dtype(input.data.dtype),
+             "Input must be in higher precision.");
+
+  NVTE_CHECK(is_fp8_dtype(output->data.dtype),
+             "Output must have FP8 type.");
+  NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
+
+  const size_t N = product(input.data.shape);
+  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
+    TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output->data.dtype, OType,
       constexpr int nvec = 32 / sizeof(IType);
       VectorizedUnaryKernelLauncher<nvec, detail::Empty, detail::identity>(
-        reinterpret_cast<const IType*>(input.dptr),
-        reinterpret_cast<OType*>(output->dptr),
-        reinterpret_cast<const fp32*>(scale.dptr),
-        reinterpret_cast<fp32*>(scale_inv->dptr),
-        reinterpret_cast<fp32*>(amax->dptr),
+        reinterpret_cast<const IType*>(input.data.dptr),
+        reinterpret_cast<OType*>(output->data.dptr),
+        reinterpret_cast<const fp32*>(output->scale.dptr),
+        reinterpret_cast<fp32*>(output->scale_inv.dptr),
+        reinterpret_cast<fp32*>(output->amax.dptr),
         N,
         {},
         stream);
     );  // NOLINT(*)
   );  // NOLINT(*)
 }

 void fp8_dequantize(const Tensor &input,
-                    const Tensor &scale_inv,
                     Tensor *output,
                     cudaStream_t stream) {
+  CheckInputTensor(input, "cast_input");
+  CheckOutputTensor(*output, "cast_output");
-  NVTE_CHECK(input.dtype == DType::kFloat8E4M3 ||
-             input.dtype == DType::kFloat8E5M2,
-             "Input must have FP8 type.");
-  NVTE_CHECK(input.dptr != nullptr, "Input is not allocated.");
-  NVTE_CHECK(output->dptr != nullptr, "Output is not allocated.");
-  NVTE_CHECK(output->dtype != DType::kFloat8E4M3 &&
-             output->dtype != DType::kFloat8E5M2,
-             "Output must be in higher precision.");
-  NVTE_CHECK(output->shape == input.shape, "Input and output shapes need to match.");
-
-  NVTE_CHECK(scale_inv.dptr != nullptr, "Inverted scale is not allocated.");
-  NVTE_CHECK(scale_inv.dtype == DType::kFloat32, "Inverted scale must have FP32 type.");
-  NVTE_CHECK(scale_inv.shape == std::vector<size_t>{ 1 }, "Inverted scale must have 1 element.");
+  NVTE_CHECK(is_fp8_dtype(input.data.dtype),
+             "Input must have FP8 type.");
+  NVTE_CHECK(!is_fp8_dtype(output->data.dtype),
+             "Output must be in higher precision.");
+  NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");

-  const size_t N = product(input.shape);
-  TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(input.dtype, IType,
-    TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(output->dtype, OType,
+  const size_t N = product(input.data.shape);
+  TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(input.data.dtype, IType,
+    TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(output->data.dtype, OType,
       constexpr int nvec = 32 / sizeof(OType);
       detail::DequantizeParam p;
-      p.scale_inv = reinterpret_cast<const fp32*>(scale_inv.dptr);
+      p.scale_inv = reinterpret_cast<const fp32*>(input.scale_inv.dptr);
       VectorizedUnaryKernelLauncher<nvec, detail::DequantizeParam, detail::dequantize_func>(
-        reinterpret_cast<const IType*>(input.dptr),
-        reinterpret_cast<OType*>(output->dptr),
+        reinterpret_cast<const IType*>(input.data.dptr),
+        reinterpret_cast<OType*>(output->data.dptr),
         nullptr,
         nullptr,
         nullptr,
         N,
         p,
         stream);
     );  // NOLINT(*)
   );  // NOLINT(*)
 }

 }  // namespace transformer_engine

 void nvte_fp8_quantize(const NVTETensor input,
-                       const NVTETensor scale,
                        NVTETensor output,
-                       NVTETensor amax,
-                       NVTETensor scale_inv,
                        cudaStream_t stream) {
   using namespace transformer_engine;
   fp8_quantize(*reinterpret_cast<const Tensor*>(input),
-               *reinterpret_cast<const Tensor*>(scale),
                reinterpret_cast<Tensor*>(output),
-               reinterpret_cast<Tensor*>(amax),
-               reinterpret_cast<Tensor*>(scale_inv),
                stream);
 }

 void nvte_fp8_dequantize(const NVTETensor input,
-                         const NVTETensor scale_inv,
                          NVTETensor output,
                          cudaStream_t stream) {
   using namespace transformer_engine;
   fp8_dequantize(*reinterpret_cast<const Tensor*>(input),
-                 *reinterpret_cast<const Tensor*>(scale_inv),
                  reinterpret_cast<Tensor*>(output),
                  stream);
 }
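For clarity, a hedged sketch of an FP8 round trip under the new interface: quantize reads its scale from the *output* tensor and writes amax/scale_inv back into it, while dequantize reads scale_inv from its FP8 *input*. The two `nvte_` signatures come from this diff; the function name, header paths, dtypes, and allocation are illustrative assumptions.

```cpp
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>  // header path assumed
#include <transformer_engine/cast.h>                // header path assumed

// Hypothetical helper: all buffers assumed device-allocated by the caller.
void fp8_roundtrip_example(void* hp_in, void* fp8_buf, void* hp_out,
                           float* amax, float* scale, float* scale_inv,
                           size_t N, size_t H, cudaStream_t stream) {
  using transformer_engine::TensorWrapper;
  using transformer_engine::DType;
  TensorWrapper input(hp_in, {N, H}, DType::kFloat32, nullptr, nullptr, nullptr);
  TensorWrapper fp8(fp8_buf, {N, H}, DType::kFloat8E4M3, amax, scale, scale_inv);
  TensorWrapper output(hp_out, {N, H}, DType::kFloat32, nullptr, nullptr, nullptr);

  nvte_fp8_quantize(input.data(), fp8.data(), stream);     // fills fp8's amax and scale_inv
  nvte_fp8_dequantize(fp8.data(), output.data(), stream);  // consumes fp8's scale_inv
}
```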
@@ -44,6 +44,41 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor)
 }

+transformer_engine::TensorWrapper makeTransformerEngineTensor(
+    void* data_ptr,
+    const std::vector<size_t>& shape,
+    const transformer_engine::DType type,
+    void* amax_ptr,
+    void* scale_ptr,
+    void* scale_inv_ptr) {
+  return transformer_engine::TensorWrapper(data_ptr, shape, type,
+                                           reinterpret_cast<float*>(amax_ptr),
+                                           reinterpret_cast<float*>(scale_ptr),
+                                           reinterpret_cast<float*>(scale_inv_ptr));
+}
+
+transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor,
+                                                              at::Tensor amax,
+                                                              const at::Tensor scale,
+                                                              at::Tensor scale_inv) {
+  transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type());
+  std::vector<size_t> shape;
+  for (auto s : tensor.sizes()) {
+    shape.push_back(s);
+  }
+  NVTE_CHECK(amax.scalar_type() == at::kFloat);
+  NVTE_CHECK(scale.scalar_type() == at::kFloat);
+  NVTE_CHECK(scale_inv.scalar_type() == at::kFloat);
+  return makeTransformerEngineTensor(tensor.data_ptr(), shape, dtype,
+                                     amax.data_ptr(),
+                                     scale.data_ptr(),
+                                     scale_inv.data_ptr());
+}
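A hypothetical PyTorch-side usage of the new at::Tensor overload, to show how the scaling state now travels with the data tensor. `fp8_out` is a placeholder for a framework-owned FP8 tensor; the side tensors must be FP32 per the checks above, and are conventionally single-element.

```cpp
// Assumed pre-existing: at::Tensor fp8_out (FP8-typed storage, framework-specific).
at::Tensor amax = at::zeros({1}, at::CUDA(at::kFloat));
at::Tensor scale = at::ones({1}, at::CUDA(at::kFloat));
at::Tensor scale_inv = at::ones({1}, at::CUDA(at::kFloat));
// One call binds data + amax + scale + scale_inv into a single TE tensor.
auto te_out = makeTransformerEngineTensor(fp8_out, amax, scale, scale_inv);
// te_out.data() can now be handed to any NVTE kernel that produces FP8 output.
```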
 size_t product(const std::vector<size_t> &shape) {
   size_t ret = 1;
   for (auto s : shape) {
@@ -124,21 +159,16 @@ void dispatch_layernorm(void* input, // i
   auto input_cu = makeTransformerEngineTensor(input, input_shape, input_type);
   auto gamma_cu = makeTransformerEngineTensor(gamma, gamma_shape, gamma_type);
   auto beta_cu = makeTransformerEngineTensor(beta, beta_shape, beta_type);
-  auto scale_cu = makeTransformerEngineTensor(scale, scale_shape, scale_type);
-  auto z_cu = makeTransformerEngineTensor(z, z_shape, z_type);
+  auto z_cu = makeTransformerEngineTensor(z, z_shape, z_type, amax, scale, scale_inv);
   auto mu_cu = makeTransformerEngineTensor(mu, mu_shape, mu_type);
   auto rsigma_cu = makeTransformerEngineTensor(rsigma, rsigma_shape, rsigma_type);
-  auto amax_cu = makeTransformerEngineTensor(amax, amax_shape, amax_type);
-  auto scale_inv_cu = makeTransformerEngineTensor(scale_inv, scale_inv_shape, scale_inv_type);
   transformer_engine::TensorWrapper workspace, barrier;

   // This call populates workspace and barrier tensors with the required config
   nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(),
-                     scale_cu.data(), epsilon,
-                     z_cu.data(), mu_cu.data(), rsigma_cu.data(),
+                     epsilon, z_cu.data(), mu_cu.data(), rsigma_cu.data(),
                      at::cuda::getCurrentCUDAStream(), multiProcessorCount,
-                     workspace.data(), barrier.data(), amax_cu.data(),
-                     scale_inv_cu.data());
+                     workspace.data(), barrier.data());

   // Fill workspace and barrier
   auto workspace_data = allocateSpace(workspace.shape(),
@@ -155,11 +185,9 @@ void dispatch_layernorm(void* input, // i

   // Actual call to fwd kernel
   nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(),
-                     scale_cu.data(), epsilon,
-                     z_cu.data(), mu_cu.data(), rsigma_cu.data(),
+                     epsilon, z_cu.data(), mu_cu.data(), rsigma_cu.data(),
                      at::cuda::getCurrentCUDAStream(), multiProcessorCount,
-                     workspace.data(), barrier.data(), amax_cu.data(),
-                     scale_inv_cu.data());
+                     workspace.data(), barrier.data());
 }
@@ -184,17 +212,13 @@ void dispatch_cast_transpose_fusion(void* input,
 ) {
   auto input_cu = makeTransformerEngineTensor(input, input_shape, input_type);
   auto output_cast_cu = makeTransformerEngineTensor(output_cast, output_cast_shape,
-                                                    output_cast_type);
+                                                    output_cast_type, amax, scale,
+                                                    scale_inv);
   auto output_transpose_cu = makeTransformerEngineTensor(output_transpose, output_transpose_shape,
-                                                         output_transpose_type);
-  auto scale_cu = makeTransformerEngineTensor(scale, scale_shape, scale_type);
-  auto amax_cu = makeTransformerEngineTensor(amax, amax_shape, amax_type);
-  auto scale_inv_cu = makeTransformerEngineTensor(scale_inv, scale_inv_shape,
-                                                  scale_inv_type);
+                                                         output_transpose_type, amax,
+                                                         scale, scale_inv);

-  nvte_cast_transpose(input_cu.data(), scale_cu.data(),
-                      output_cast_cu.data(), output_transpose_cu.data(),
-                      amax_cu.data(), scale_inv_cu.data(),
+  nvte_cast_transpose(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(),
                       at::cuda::getCurrentCUDAStream());
 }
@@ -216,13 +240,10 @@ void dispatch_gelu(void* input, // i
                    const transformer_engine::DType scale_inv_type
 ) {
   auto input_cu = makeTransformerEngineTensor(input, input_shape, input_type);
-  auto output_cu = makeTransformerEngineTensor(output, output_shape, output_type);
-  auto scale_cu = makeTransformerEngineTensor(scale, scale_shape, scale_type);
-  auto amax_cu = makeTransformerEngineTensor(amax, amax_shape, amax_type);
-  auto scale_inv_cu = makeTransformerEngineTensor(scale_inv, scale_inv_shape, scale_inv_type);
+  auto output_cu = makeTransformerEngineTensor(output, output_shape, output_type,
+                                               amax, scale, scale_inv);

-  nvte_gelu(input_cu.data(), output_cu.data(), scale_cu.data(),
-            amax_cu.data(), scale_inv_cu.data(), at::cuda::getCurrentCUDAStream());
+  nvte_gelu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
 }
@@ -263,22 +284,18 @@ void dispatch_bgrad_cast_transpose_fusion(void* input,
                    const transformer_engine::DType scale_inv_type
 ) {
   auto input_cu = makeTransformerEngineTensor(input, input_shape, input_type);
-  auto scale_cu = makeTransformerEngineTensor(scale, scale_shape, scale_type);
   auto cast_output_cu = makeTransformerEngineTensor(cast_output, cast_output_shape,
-                                                    cast_output_type);
+                                                    cast_output_type, amax, scale,
+                                                    scale_inv);
   auto transposed_output_cu = makeTransformerEngineTensor(transposed_output,
                                                           transposed_output_shape,
-                                                          transposed_output_type);
-  auto amax_cu = makeTransformerEngineTensor(amax, amax_shape, amax_type);
+                                                          transposed_output_type,
+                                                          amax, scale, scale_inv);
   auto dbias_cu = makeTransformerEngineTensor(dbias, dbias_shape, dbias_type);
-  auto scale_inv_cu = makeTransformerEngineTensor(scale_inv,
-                                                  scale_inv_shape,
-                                                  scale_inv_type);
   transformer_engine::TensorWrapper workspace;

-  nvte_cast_transpose_dbias(input_cu.data(), scale_cu.data(), cast_output_cu.data(),
-                            transposed_output_cu.data(), amax_cu.data(),
-                            dbias_cu.data(), scale_inv_cu.data(),
+  nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(),
+                            transposed_output_cu.data(), dbias_cu.data(),
                             workspace.data(), at::cuda::getCurrentCUDAStream());

   // Fill workspace
@@ -287,10 +304,9 @@ void dispatch_bgrad_cast_transpose_fusion(void* input,
                                          workspace.shape(),
                                          workspace.dtype());

-  nvte_cast_transpose_dbias(input_cu.data(), scale_cu.data(), cast_output_cu.data(),
-                            transposed_output_cu.data(), amax_cu.data(),
-                            dbias_cu.data(), scale_inv_cu.data(), workspace.data(),
-                            at::cuda::getCurrentCUDAStream());
+  nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(),
+                            transposed_output_cu.data(), dbias_cu.data(),
+                            workspace.data(), at::cuda::getCurrentCUDAStream());
 }
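The dispatch functions above all follow the same two-phase workspace protocol; a sketch of that pattern in isolation, assuming the helpers from this file (`allocateSpace` is assumed to return an at::Tensor, and `TensorWrapper` is assumed move-assignable):

```cpp
// Hypothetical helper, same shape as the dispatch code above.
void run_with_workspace(const transformer_engine::TensorWrapper& input_cu,
                        transformer_engine::TensorWrapper& cast_output_cu,
                        transformer_engine::TensorWrapper& transposed_output_cu,
                        transformer_engine::TensorWrapper& dbias_cu) {
  transformer_engine::TensorWrapper workspace;
  // Pass 1: workspace has a null data pointer, so the kernel only records the
  // required shape/dtype (see the `workspace->data.dptr == nullptr` branch above).
  nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(),
                            transposed_output_cu.data(), dbias_cu.data(),
                            workspace.data(), at::cuda::getCurrentCUDAStream());
  // Allocate the requested buffer and re-wrap it.
  auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
  workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
                                          workspace.shape(), workspace.dtype());
  // Pass 2: the actual fused computation.
  nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(),
                            transposed_output_cu.data(), dbias_cu.data(),
                            workspace.data(), at::cuda::getCurrentCUDAStream());
}
```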
@@ -324,22 +340,19 @@ void dispatch_bgrad_dgelu_cast_transpose_fusion(
   auto gelu_input_cu = makeTransformerEngineTensor(gelu_input, gelu_input_shape,
                                                    gelu_input_type);
   auto input_cu = makeTransformerEngineTensor(input, input_shape, input_type);
-  auto scale_cu = makeTransformerEngineTensor(scale, scale_shape, scale_type);
   auto cast_output_cu = makeTransformerEngineTensor(cast_output, cast_output_shape,
-                                                    cast_output_type);
+                                                    cast_output_type, amax, scale,
+                                                    scale_inv);
   auto transposed_output_cu = makeTransformerEngineTensor(transposed_output,
                                                           transposed_output_shape,
-                                                          transposed_output_type);
-  auto amax_cu = makeTransformerEngineTensor(amax, amax_shape, amax_type);
+                                                          transposed_output_type,
+                                                          amax, scale, scale_inv);
   auto dbias_cu = makeTransformerEngineTensor(dbias, dbias_shape, dbias_type);
-  auto scale_inv_cu = makeTransformerEngineTensor(scale_inv,
-                                                  scale_inv_shape,
-                                                  scale_inv_type);

-  nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), scale_cu.data(),
+  nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(),
                                   cast_output_cu.data(), transposed_output_cu.data(),
-                                  amax_cu.data(), dbias_cu.data(), scale_inv_cu.data(),
-                                  workspace.data(), at::cuda::getCurrentCUDAStream());
+                                  dbias_cu.data(), workspace.data(),
+                                  at::cuda::getCurrentCUDAStream());

   // Fill workspace
   auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
@@ -347,10 +360,10 @@ void dispatch_bgrad_dgelu_cast_transpose_fusion(
                                          workspace.shape(),
                                          workspace.dtype());

-  nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), scale_cu.data(),
+  nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(),
                                   cast_output_cu.data(), transposed_output_cu.data(),
-                                  amax_cu.data(), dbias_cu.data(), scale_inv_cu.data(),
-                                  workspace.data(), at::cuda::getCurrentCUDAStream());
+                                  dbias_cu.data(), workspace.data(),
+                                  at::cuda::getCurrentCUDAStream());
 }
@@ -377,56 +390,51 @@ void dispatch_multi_cast_transpose(
   transformer_engine::TensorWrapper workspace;

   // Construct TE tensors
-  std::vector<NVTETensor> input_list, scale_list,
-                          cast_output_list, transposed_output_list, amax_list, scale_inv_list;
+  std::vector<NVTETensor> input_list,
+                          cast_output_list, transposed_output_list;
   std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
   auto make_tensor = [&tensor_wrappers](void* dptr,
                                         const std::vector<size_t>& shape,
-                                        transformer_engine::DType dtype)
+                                        transformer_engine::DType dtype,
+                                        void* amax_dptr,
+                                        void* scale_dptr,
+                                        void* scale_inv_dptr)
    -> NVTETensor {
-    tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype));
+    tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr,
+                                                             scale_dptr, scale_inv_dptr));
     return tensor_wrappers.back().data();
   };
   for (size_t i = 0; i < input_dptr_list.size(); ++i) {
     input_list.emplace_back(make_tensor(input_dptr_list[i],
                                         input_shape_list[i],
-                                        input_type_list[i]));
-    scale_list.emplace_back(make_tensor(scale_dptr_list[i],
-                                        scale_shape_list[i],
-                                        scale_type_list[i]));
+                                        input_type_list[i],
+                                        nullptr,
+                                        nullptr,
+                                        nullptr));
     cast_output_list.emplace_back(make_tensor(cast_output_dptr_list[i],
                                               cast_output_shape_list[i],
-                                              cast_output_type_list[i]));
+                                              cast_output_type_list[i],
+                                              amax_dptr_list[i],
+                                              scale_dptr_list[i],
+                                              scale_inv_dptr_list[i]));
     transposed_output_list.emplace_back(make_tensor(transposed_output_dptr_list[i],
                                                     transposed_output_shape_list[i],
-                                                    transposed_output_type_list[i]));
-    amax_list.emplace_back(make_tensor(amax_dptr_list[i],
-                                       amax_shape_list[i],
-                                       amax_type_list[i]));
-    scale_inv_list.emplace_back(make_tensor(scale_inv_dptr_list[i],
-                                            scale_inv_shape_list[i],
-                                            scale_inv_type_list[i]));
+                                                    transposed_output_type_list[i],
+                                                    amax_dptr_list[i],
+                                                    scale_dptr_list[i],
+                                                    scale_inv_dptr_list[i]));
   }

   // Check tensor lists
-  NVTE_CHECK(scale_list.size() == input_list.size(),
-             "Number of input and scale tensors must match");
   NVTE_CHECK(cast_output_list.size() == input_list.size(),
              "Number of input and C output tensors must match");
   NVTE_CHECK(transposed_output_list.size() == input_list.size(),
              "Number of input and T output tensors must match");
-  NVTE_CHECK(amax_list.size() == input_list.size(),
-             "Number of input and AMAX tensors must match");
-  NVTE_CHECK(scale_inv_list.size() == input_list.size(),
-             "Number of input and scale_inv tensors must match");

   // Launch TE kernel
   nvte_multi_cast_transpose(input_list.size(),
                             input_list.data(),
-                            scale_list.data(),
                             cast_output_list.data(),
                             transposed_output_list.data(),
-                            amax_list.data(),
-                            scale_inv_list.data(),
                             at::cuda::getCurrentCUDAStream());
 }
@@ -109,6 +109,14 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
                                                               const transformer_engine::DType type
                                                               );

+transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
+                                                              const std::vector<size_t>& shape,
+                                                              const transformer_engine::DType type,
+                                                              void* amax_ptr,
+                                                              void* scale_ptr,
+                                                              void* scale_inv_ptr
+                                                              );
+
 transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
                                                               const NVTEShape& shape,
@@ -118,6 +126,11 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
 transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor);

+transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor,
+                                                              at::Tensor amax,
+                                                              const at::Tensor scale,
+                                                              at::Tensor scale_inv);
+
 size_t product(const std::vector<size_t> &shape);
...
@@ -29,17 +29,13 @@ void te_gemm(at::Tensor A,
   auto te_A = makeTransformerEngineTensor(A.data_ptr(),
                                           {static_cast<size_t>(A.size(0)),
                                            static_cast<size_t>(A.size(1))},
-                                          A_type);
-  auto te_A_scale_inverse = makeTransformerEngineTensor(A_scale_inverse.data_ptr(), {1},
-                                                        GetTransformerEngineDType(
-                                                            A_scale_inverse.scalar_type()));
+                                          A_type, nullptr, nullptr,
+                                          A_scale_inverse.data_ptr());
   auto te_B = makeTransformerEngineTensor(B.data_ptr(),
                                           {static_cast<size_t>(B.size(0)),
                                            static_cast<size_t>(B.size(1))},
-                                          B_type);
-  auto te_B_scale_inverse = makeTransformerEngineTensor(B_scale_inverse.data_ptr(), {1},
-                                                        GetTransformerEngineDType(
-                                                            B_scale_inverse.scalar_type()));
+                                          B_type, nullptr, nullptr,
+                                          B_scale_inverse.data_ptr());
   auto te_D = makeTransformerEngineTensor(D.data_ptr(),
                                           {static_cast<size_t>(D.size(0)),
                                            static_cast<size_t>(D.size(1))},
@@ -60,9 +56,7 @@ void te_gemm(at::Tensor A,
                                              DType::kByte);

   nvte_cublas_gemm(te_A.data(),
-                   te_A_scale_inverse.data(),
                    te_B.data(),
-                   te_B_scale_inverse.data(),
                    te_D.data(),
                    te_bias.data(),
                    te_pre_gelu_out.data(),
@@ -448,13 +442,11 @@ at::Tensor cast_to_fp8(const at::Tensor &input,
   auto output = at::empty_like(input, at::CUDA(GetATenDType(otype)));

   auto input_cu = makeTransformerEngineTensor(input);
-  auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, H}, otype);
-  auto scale_cu = makeTransformerEngineTensor(scale.data_ptr(), {1}, DType::kFloat32);
-  auto amax_cu = makeTransformerEngineTensor(amax.data_ptr(), {1}, DType::kFloat32);
-  auto scale_inv_cu = makeTransformerEngineTensor(scale_inv.data_ptr(), {1}, DType::kFloat32);
+  auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, H}, otype,
+                                               amax.data_ptr(), scale.data_ptr(),
+                                               scale_inv.data_ptr());

-  nvte_fp8_quantize(input_cu.data(), scale_cu.data(), output_cu.data(),
-                    amax_cu.data(), scale_inv_cu.data(),
+  nvte_fp8_quantize(input_cu.data(), output_cu.data(),
                     at::cuda::getCurrentCUDAStream());

   return output;
@@ -472,11 +464,12 @@ at::Tensor cast_from_fp8(const at::Tensor &input,
   auto output = at::empty_like(input, at::CUDA(GetATenDType(otype)));

-  auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {N, H}, itype);
+  auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {N, H}, itype,
+                                              nullptr, nullptr, scale_inv.data_ptr());
   auto output_cu = makeTransformerEngineTensor(output);
   auto scale_inv_cu = makeTransformerEngineTensor(scale_inv.data_ptr(), {1}, DType::kFloat32);

-  nvte_fp8_dequantize(input_cu.data(), scale_inv_cu.data(), output_cu.data(),
+  nvte_fp8_dequantize(input_cu.data(), output_cu.data(),
                       at::cuda::getCurrentCUDAStream());

   return output;
...