Unverified commit a5ba71f3 authored by Przemyslaw Tredak, committed by GitHub
Browse files

Move the amax/scale/scale_inv into the TE Tensor struct. (#33)



* Move the amax/scale/scale_inv into the TE Tensor struct.
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Handle multi_cast_transpose
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Changed softmax to new Tensor
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* First pass at the cpp tests
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Round of fixes
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Fix multi_cast_transpose
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Fix cast_to_fp8
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarPrzemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>
Signed-off-by: default avatarPrzemyslaw Tredak <ptrendx@gmail.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 509bf877
......@@ -24,19 +24,13 @@ extern "C" {
* - `transposed_output` is the transposed result of the cast.
*
* \param[in] input Input tensor of shape [N, H].
* \param[in] scale Scaling factor used for outputs.
* \param[out] cast_output Result of the cast. Shape: [N, H].
* \param[out] transposed_output Result of the cast and transpose. Shape: [H, N].
* \param[in,out] amax AMAX value of the output tensor.
* \param[out] scale_inv Inverse of the output's scaling factor.
* \param[in,out] cast_output Result of the cast. Shape: [N, H].
* \param[in,out] transposed_output Result of the cast and transpose. Shape: [H, N].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cast_transpose(const NVTETensor input,
const NVTETensor scale,
NVTETensor cast_output,
NVTETensor transposed_output,
NVTETensor amax,
NVTETensor scale_inv,
cudaStream_t stream);
/*! \brief Transpose the input.
......@@ -60,23 +54,17 @@ void nvte_transpose(const NVTETensor input,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input tensor of shape [N, H].
* \param[in] scale Scaling factor used for outputs.
* \param[out] cast_output Result of the cast. Shape: [N, H].
* \param[out] transposed_output Result of the cast and transpose. Shape: [H, N].
* \param[in,out] amax AMAX value of the output tensor.
* \param[in,out] cast_output Result of the cast. Shape: [N, H].
* \param[in,out] transposed_output Result of the cast and transpose. Shape: [H, N].
* \param[out] dbias Result of the reduction of the input along the
* first dimension. Shape: [H].
* \param[out] scale_inv Inverse of the output's scaling factor.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cast_transpose_dbias(const NVTETensor input,
const NVTETensor scale,
NVTETensor cast_output,
NVTETensor transposed_output,
NVTETensor amax,
NVTETensor dbias,
NVTETensor scale_inv,
NVTETensor workspace,
cudaStream_t stream);
......@@ -94,24 +82,18 @@ void nvte_cast_transpose_dbias(const NVTETensor input,
* \param[in] input Input tensor of shape [N, H].
* \param[in] gelu_input Tensor used as input to the forward of GELU operation.
* Shape [N, H].
* \param[in] scale Scaling factor used for outputs.
* \param[out] cast_output Result of the cast. Shape: [N, H].
* \param[out] transposed_output Result of the cast and transpose. Shape: [H, N].
* \param[in,out] amax AMAX value of the output tensor.
* \param[in,out] cast_output Result of the cast. Shape: [N, H].
* \param[in,out] transposed_output Result of the cast and transpose. Shape: [H, N].
* \param[out] dbias Result of the reduction of the dGELU(input) along the
* first dimension. Shape: [H].
* \param[out] scale_inv Inverse of the output's scaling factor.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cast_transpose_dbias_dgelu(const NVTETensor input,
const NVTETensor gelu_input,
const NVTETensor scale,
NVTETensor cast_output,
NVTETensor transposed_output,
NVTETensor amax,
NVTETensor dbias,
NVTETensor scale_inv,
NVTETensor workspace,
cudaStream_t stream);
......@@ -123,23 +105,17 @@ void nvte_cast_transpose_dbias_dgelu(const NVTETensor input,
*
* \param[in] num_tensors Number of tensors.
* \param[in] input_list List of 2D input tensors.
* \param[in] scale_list Scaling factor to generate outputs.
* \param[out] cast_output_list List of casted tensors. Dimensions
* \param[in,out] cast_output_list List of casted tensors. Dimensions
* match tensors in input_list.
* \param[out] transposed_output_list List of casted and transposed
* \param[in,out] transposed_output_list List of casted and transposed
* tensors. Dimensions are transpose
* of tensors in input_list.
* \param[in,out] amax_list AMAX values of the output tensors.
* \param[out] scale_inv_list Inverses of the scaling factors.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_multi_cast_transpose(size_t num_tensors,
const NVTETensor* input_list,
const NVTETensor* scale_list,
NVTETensor* cast_output_list,
NVTETensor* transposed_output_list,
NVTETensor* amax_list,
NVTETensor* scale_inv_list,
cudaStream_t stream);
#ifdef __cplusplus
......
......@@ -141,10 +141,64 @@ extern BwdGeneralRegistry BWD_GENERAL_FUNCS;
////////////////////////////////////////////////////////////////////////////////////////////////////
// Maps each supported data type to a small integer identifier. These
// identifiers are shifted and OR-ed together below to build the lookup
// keys of the kernel registries. The primary template is intentionally
// empty so that using an unsupported type fails at compile time.
template<typename T>
struct TypeId{};

template<> struct TypeId<fp16>    { constexpr static uint32_t Value = 0; };
template<> struct TypeId<bf16>    { constexpr static uint32_t Value = 1; };
template<> struct TypeId<fp32>    { constexpr static uint32_t Value = 2; };
template<> struct TypeId<fp8e4m3> { constexpr static uint32_t Value = 3; };
// Places a type identifier at bit offset S inside a composite key.
template<typename T, int S>
struct Type2Key{
    constexpr static uint32_t Value = TypeId<T>::Value << S;
};

// Role-specific key fragments. Each type id occupies two bits:
// weight -> bits [0,2), input -> [2,4), output -> [4,6), compute -> [6,8).
template<typename T> struct WeightType2Key  : public Type2Key<T, 0>{};
template<typename T> struct InputType2Key   : public Type2Key<T, 2>{};
template<typename T> struct OutputType2Key  : public Type2Key<T, 4>{};
template<typename T> struct ComputeType2Key : public Type2Key<T, 6>{};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Combines the four per-role type fragments into a single descriptor and,
// via get(), widens it into the 64-bit registry key: the type bits occupy
// the upper 32 bits and hidden_size occupies the lower 32 bits.
template<typename W, typename I, typename O, typename C>
struct Types2Key{
    constexpr static uint32_t Value = WeightType2Key<W>::Value |
                                      InputType2Key<I>::Value |
                                      OutputType2Key<O>::Value |
                                      ComputeType2Key<C>::Value;
    constexpr static inline uint64_t get(const uint64_t hidden_size){
        constexpr uint64_t type_key = Value;
        return (type_key << 32) | hidden_size;
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Registers a tuned forward function in FWD_TUNED_FUNCS under a 64-bit key
// that encodes the weight/input/output/compute types and HIDDEN_SIZE.
// Fix: the body contained two declarations of `key` (stale merge residue),
// which is a redefinition error; keep the unqualified form matching the
// Types2Key helper defined in this file.
template<typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdTunedRegistrar{
    explicit FwdTunedRegistrar(FwdFunction f){
        const uint64_t key = Types2Key<W, I, O, C>::get(HIDDEN_SIZE);
        FWD_TUNED_FUNCS.insert({ key, f });
    }
};
......@@ -154,7 +208,7 @@ struct FwdTunedRegistrar{
// Registers a general (hidden-size-agnostic) forward function: the outer
// key encodes only the types (hidden size component 0); the inner map is
// keyed by HIDDEN_SIZE. Fix: removed the duplicate `key` declaration left
// over from a merge, which made the constructor ill-formed.
template<typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdGeneralRegistrar{
    explicit FwdGeneralRegistrar(FwdFunction f){
        const uint64_t key = Types2Key<W, I, O, C>::get(0);
        FWD_GENERAL_FUNCS[key].insert({ HIDDEN_SIZE, f });
    }
};
......@@ -164,7 +218,7 @@ struct FwdGeneralRegistrar{
// Registers a tuned backward function in BWD_TUNED_FUNCS under a 64-bit key
// that encodes the weight/input/output/compute types and HIDDEN_SIZE.
// Fix: removed the duplicate `key` declaration (stale merge residue) that
// made the constructor ill-formed.
template<typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdTunedRegistrar{
    explicit BwdTunedRegistrar(BwdFunction f){
        const uint64_t key = Types2Key<W, I, O, C>::get(HIDDEN_SIZE);
        BWD_TUNED_FUNCS.insert({ key, f });
    }
};
......@@ -174,7 +228,7 @@ struct BwdTunedRegistrar{
// Registers a general (hidden-size-agnostic) backward function: the outer
// key encodes only the types (hidden size component 0); the inner map is
// keyed by HIDDEN_SIZE. Fix: removed the duplicate `key` declaration left
// over from a merge, which made the constructor ill-formed.
template<typename W, typename I, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdGeneralRegistrar{
    explicit BwdGeneralRegistrar(BwdFunction f){
        const uint64_t key = Types2Key<W, I, O, C>::get(0);
        BWD_GENERAL_FUNCS[key].insert({ HIDDEN_SIZE, f });
    }
};
......@@ -187,6 +241,8 @@ layer_norm::BwdFunction & get_bwd_launcher(DType wtype,
DType ctype,
uint32_t hidden_size);
//////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace layer_norm
} // namespace transformer_engine
......
......@@ -144,7 +144,6 @@ size_t product(const std::vector<size_t> &shape) {
void layernorm_fwd(const Tensor& x, // BxSxhidden_size
const Tensor& gamma, // hidden_size
const Tensor& beta, // hidden_size
const Tensor& scale,
const float epsilon,
Tensor* z,
Tensor* mu,
......@@ -152,50 +151,39 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
cudaStream_t stream,
const int multiprocessorCount,
Tensor* workspace,
Tensor* barrier,
Tensor* amax,
Tensor *scale_inv
) {
auto itype = x.dtype;
auto wtype = gamma.dtype;
auto otype = z->dtype;
bool fp8_out = otype == DType::kFloat8E4M3 ||
otype == DType::kFloat8E5M2;
auto ctype = layer_norm::DType::kFloat32;
Tensor* barrier) {
const auto itype = x.data.dtype;
const auto wtype = gamma.data.dtype;
const auto otype = z->data.dtype;
const bool fp8_out = is_fp8_dtype(otype);
const auto ctype = layer_norm::DType::kFloat32;
NVTE_CHECK(x.shape.size() == 2);
CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma");
CheckInputTensor(beta, "beta");
const size_t rows = x.shape[0];
const size_t cols = x.shape[1];
auto hidden_size = gamma.shape[0];
CheckOutputTensor(*z, "z");
CheckOutputTensor(*mu, "mu");
CheckOutputTensor(*rsigma, "rsigma");
NVTE_CHECK(gamma.shape == beta.shape);
NVTE_CHECK(hidden_size == cols);
NVTE_CHECK(x.data.shape.size() == 2);
NVTE_CHECK(epsilon >= 0.f);
const size_t rows = x.data.shape[0];
const size_t cols = x.data.shape[1];
const auto hidden_size = gamma.data.shape[0];
NVTE_CHECK(z->dptr != nullptr);
NVTE_CHECK(z->shape == x.shape);
NVTE_CHECK(mu->shape == std::vector<size_t>{ rows });
NVTE_CHECK(mu->dtype == ctype);
NVTE_CHECK(gamma.data.shape == beta.data.shape);
NVTE_CHECK(hidden_size == cols);
NVTE_CHECK(rsigma->shape == std::vector<size_t>{ rows });
NVTE_CHECK(rsigma->dtype == ctype);
NVTE_CHECK(epsilon >= 0.f);
if (fp8_out) {
NVTE_CHECK(scale.shape == std::vector<size_t>{ 1 });
NVTE_CHECK(scale.dptr != nullptr);
NVTE_CHECK(scale.dtype == ctype);
NVTE_CHECK(z->data.shape == x.data.shape);
NVTE_CHECK(amax->shape == std::vector<size_t>{ 1 });
NVTE_CHECK(amax->dptr != nullptr);
NVTE_CHECK(amax->dtype == ctype);
NVTE_CHECK(mu->data.shape == std::vector<size_t>{ rows });
NVTE_CHECK(mu->data.dtype == ctype);
NVTE_CHECK(scale_inv->shape == std::vector<size_t>{ 1 });
NVTE_CHECK(scale_inv->dptr != nullptr);
NVTE_CHECK(scale_inv->dtype == ctype);
}
NVTE_CHECK(rsigma->data.shape == std::vector<size_t>{ rows });
NVTE_CHECK(rsigma->data.dtype == ctype);
layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;
......@@ -210,49 +198,49 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
layer_norm::FwdParams &params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x = x.dptr;
params.mu = mu->dptr;
params.rs = rsigma->dptr;
params.gamma = gamma.dptr;
params.beta = beta.dptr;
params.z = z->dptr;
params.x = x.data.dptr;
params.mu = mu->data.dptr;
params.rs = rsigma->data.dptr;
params.gamma = gamma.data.dptr;
params.beta = beta.data.dptr;
params.z = z->data.dptr;
params.epsilon = epsilon;
params.amax = amax->dptr;
params.scale = scale.dptr;
params.scale_inv = scale_inv->dptr;
params.amax = z->amax.dptr;
params.scale = z->scale.dptr;
params.scale_inv = z->scale_inv.dptr;
params.fp8_out = fp8_out;
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
if (workspace->dptr == nullptr) {
NVTE_CHECK(barrier->dptr == nullptr);
if (workspace->data.dptr == nullptr) {
NVTE_CHECK(barrier->data.dptr == nullptr);
workspace->dtype = layer_norm::DType::kByte;
workspace->data.dtype = layer_norm::DType::kByte;
if (launch_params.workspace_bytes == 0) {
launch_params.workspace_bytes = 1;
}
workspace->shape = { launch_params.workspace_bytes };
workspace->data.shape = { launch_params.workspace_bytes };
barrier->dtype = layer_norm::DType::kInt32;
barrier->shape = { launch_params.barrier_size };
barrier->data.dtype = layer_norm::DType::kInt32;
barrier->data.shape = { launch_params.barrier_size };
return;
}
if ( launch_params.barrier_size > 0 ) {
params.workspace = workspace->dptr;
params.barrier = reinterpret_cast<int*>(barrier->dptr);
params.workspace = workspace->data.dptr;
params.barrier = reinterpret_cast<int*>(barrier->data.dptr);
}
// Clear buffers
if ( params.fp8_out ) {
cudaMemsetAsync(params.amax, 0,
layer_norm::product(amax->shape) *
typeToSize(amax->dtype), stream);
layer_norm::product(z->amax.shape) *
typeToSize(z->amax.dtype), stream);
}
if ( launch_params.barrier_size > 0 ) {
cudaMemsetAsync(params.barrier, 0,
layer_norm::product(barrier->shape) *
typeToSize(barrier->dtype), stream);
layer_norm::product(barrier->data.shape) *
typeToSize(barrier->data.dtype), stream);
}
// Launch the kernel.
......@@ -278,38 +266,44 @@ void layernorm_bwd(const Tensor& dz,
) {
using namespace transformer_engine;
auto itype = x.dtype;
auto wtype = gamma.dtype;
auto itype = x.data.dtype;
auto wtype = gamma.data.dtype;
auto otype = wtype;
auto ctype = DType::kFloat32;
NVTE_CHECK(dz.dtype == otype);
NVTE_CHECK(mu.dtype == ctype);
NVTE_CHECK(rsigma.dtype == ctype);
CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x");
CheckInputTensor(mu, "mu");
CheckInputTensor(rsigma, "rsigma");
CheckInputTensor(gamma, "gamma");
CheckOutputTensor(*dx, "dx");
CheckOutputTensor(*dgamma, "dgamma");
CheckOutputTensor(*dbeta, "dbeta");
NVTE_CHECK(dz.data.dtype == otype);
NVTE_CHECK(mu.data.dtype == ctype);
NVTE_CHECK(rsigma.data.dtype == ctype);
NVTE_CHECK(x.shape.size() == 2);
NVTE_CHECK(dz.shape == x.shape);
auto rows = x.shape[0];
auto cols = x.shape[1];
NVTE_CHECK(x.data.shape.size() == 2);
NVTE_CHECK(dz.data.shape == x.data.shape);
auto rows = x.data.shape[0];
auto cols = x.data.shape[1];
auto hidden_size = gamma.shape[0];
auto hidden_size = gamma.data.shape[0];
NVTE_CHECK(mu.shape[0] == rows);
NVTE_CHECK(mu.shape == rsigma.shape);
NVTE_CHECK(mu.data.shape[0] == rows);
NVTE_CHECK(mu.data.shape == rsigma.data.shape);
NVTE_CHECK(gamma.shape[0] == cols);
NVTE_CHECK(gamma.data.shape[0] == cols);
NVTE_CHECK(dx->shape == x.shape);
NVTE_CHECK(dx->dtype == x.dtype);
NVTE_CHECK(dx->dptr != nullptr);
NVTE_CHECK(dx->data.shape == x.data.shape);
NVTE_CHECK(dx->data.dtype == x.data.dtype);
NVTE_CHECK(dgamma->shape == gamma.shape);
NVTE_CHECK(dgamma->dtype == gamma.dtype);
NVTE_CHECK(dgamma->dptr != nullptr);
NVTE_CHECK(dgamma->data.shape == gamma.data.shape);
NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype);
NVTE_CHECK(dbeta->shape == gamma.shape);
NVTE_CHECK(dbeta->dtype == gamma.dtype);
NVTE_CHECK(dbeta->dptr != nullptr);
NVTE_CHECK(dbeta->data.shape == gamma.data.shape);
NVTE_CHECK(dbeta->data.dtype == gamma.data.dtype);
layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
launch_params.stream = stream;
......@@ -322,47 +316,47 @@ void layernorm_bwd(const Tensor& dz,
layer_norm::BwdParams &params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x = x.dptr;
params.mu = mu.dptr;
params.rs = rsigma.dptr;
params.gamma = gamma.dptr;
params.dz = dz.dptr;
params.dx = dx->dptr;
params.dbeta = dbeta->dptr;
params.dgamma = dgamma->dptr;
params.dbeta_part = dbeta_part->dptr;
params.dgamma_part = dgamma_part->dptr;
params.x = x.data.dptr;
params.mu = mu.data.dptr;
params.rs = rsigma.data.dptr;
params.gamma = gamma.data.dptr;
params.dz = dz.data.dptr;
params.dx = dx->data.dptr;
params.dbeta = dbeta->data.dptr;
params.dgamma = dgamma->data.dptr;
params.dbeta_part = dbeta_part->data.dptr;
params.dgamma_part = dgamma_part->data.dptr;
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
// Populate shape and dtypes for FW to allocate memory
if (dgamma_part->dptr == nullptr) {
NVTE_CHECK(dbeta_part->dptr == nullptr);
if (dgamma_part->data.dptr == nullptr) {
NVTE_CHECK(dbeta_part->data.dptr == nullptr);
dgamma_part->dtype = ctype;
dgamma_part->shape = { static_cast<uint64_t> (launch_params.params.ctas_per_col),
hidden_size };
dgamma_part->data.dtype = ctype;
dgamma_part->data.shape = { static_cast<uint64_t> (launch_params.params.ctas_per_col),
hidden_size };
dbeta_part->dtype = ctype;
dbeta_part->shape = { static_cast<uint64_t> (launch_params.params.ctas_per_col),
hidden_size };
dbeta_part->data.dtype = ctype;
dbeta_part->data.shape = { static_cast<uint64_t> (launch_params.params.ctas_per_col),
hidden_size };
workspace->dtype = layer_norm::DType::kByte;
workspace->shape = { launch_params.workspace_bytes };
workspace->data.dtype = layer_norm::DType::kByte;
workspace->data.shape = { launch_params.workspace_bytes };
barrier->dtype = layer_norm::DType::kInt32;
barrier->shape = { launch_params.barrier_size };
barrier->data.dtype = layer_norm::DType::kInt32;
barrier->data.shape = { launch_params.barrier_size };
return;
}
if ( launch_params.barrier_size > 0 ) {
params.workspace = workspace->dptr;
params.barrier = reinterpret_cast<int*>(barrier->dptr);
params.workspace = workspace->data.dptr;
params.barrier = reinterpret_cast<int*>(barrier->data.dptr);
cudaMemsetAsync(params.barrier, 0,
layer_norm::product(barrier->shape) *
typeToSize(barrier->dtype), stream);
layer_norm::product(barrier->data.shape) *
typeToSize(barrier->data.dtype), stream);
}
// Launch the kernel.
......@@ -373,7 +367,6 @@ void layernorm_bwd(const Tensor& dz,
void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size
const NVTETensor gamma, // hidden_size
const NVTETensor beta, // hidden_size
const NVTETensor scale, // 1
const float epsilon,
NVTETensor z,
NVTETensor mu,
......@@ -381,14 +374,11 @@ void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size
cudaStream_t stream,
const int multiprocessorCount,
NVTETensor workspace,
NVTETensor barrier,
NVTETensor amax,
NVTETensor scale_inv) {
NVTETensor barrier) {
using namespace transformer_engine;
layernorm_fwd(*reinterpret_cast<const Tensor*>(x),
*reinterpret_cast<const Tensor*>(gamma),
*reinterpret_cast<const Tensor*>(beta),
*reinterpret_cast<const Tensor*>(scale),
epsilon,
reinterpret_cast<Tensor*>(z),
reinterpret_cast<Tensor*>(mu),
......@@ -396,9 +386,7 @@ void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size
stream,
multiprocessorCount,
reinterpret_cast<Tensor*>(workspace),
reinterpret_cast<Tensor*>(barrier),
reinterpret_cast<Tensor*>(amax),
reinterpret_cast<Tensor*>(scale_inv));
reinterpret_cast<Tensor*>(barrier));
}
void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size
......
......@@ -15,15 +15,77 @@ size_t typeToSize(const transformer_engine::DType type) {
); // NOLINT(*)
}
// True iff the given type is one of the 8-bit floating-point formats
// (E4M3 or E5M2).
bool is_fp8_dtype(const transformer_engine::DType t) {
  switch (t) {
    case transformer_engine::DType::kFloat8E4M3:
    case transformer_engine::DType::kFloat8E5M2:
      return true;
    default:
      return false;
  }
}
void CheckInputTensor(const Tensor &t, const std::string &name) {
const DType type = t.data.dtype;
if (is_fp8_dtype(type)) {
// FP8 input needs to have scale_inv
NVTE_CHECK(t.scale_inv.dptr != nullptr,
"FP8 input " + name + " must have inverse of scale.");
NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32);
NVTE_CHECK(t.scale_inv.shape == std::vector<size_t>{ 1 });
} else {
NVTE_CHECK(t.scale.dptr == nullptr,
"Scale is not supported for non-FP8 input " + name + ".");
NVTE_CHECK(t.amax.dptr == nullptr,
"Amax is not supported for non-FP8 input " + name + ".");
NVTE_CHECK(t.scale_inv.dptr == nullptr,
"Scale_inv is not supported for non-FP8 input " + name + ".");
}
NVTE_CHECK(t.data.dptr != nullptr,
"Input " + name + " is not allocated!");
}
// Validates an output tensor named `name`. An FP8 output must carry
// single-element FP32 amax, scale and scale_inv tensors; a non-FP8 output
// must not carry any of them. Unless allow_empty is set, the data buffer
// must be allocated. Violations raise via NVTE_CHECK.
// Fix: the error messages for the scale and scale_inv checks were swapped
// (the scale_inv check claimed "must have scale" and vice versa).
void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty) {
  const DType type = t.data.dtype;
  if (is_fp8_dtype(type)) {
    // FP8 output needs to have scale, amax and scale_inv
    NVTE_CHECK(t.amax.dptr != nullptr,
               "FP8 output " + name + " must have amax tensor.");
    NVTE_CHECK(t.amax.dtype == DType::kFloat32);
    NVTE_CHECK(t.amax.shape == std::vector<size_t>{ 1 });
    NVTE_CHECK(t.scale_inv.dptr != nullptr,
               "FP8 output " + name + " must have inverse of scale.");
    NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32);
    NVTE_CHECK(t.scale_inv.shape == std::vector<size_t>{ 1 });
    NVTE_CHECK(t.scale.dptr != nullptr,
               "FP8 output " + name + " must have scale.");
    NVTE_CHECK(t.scale.dtype == DType::kFloat32);
    NVTE_CHECK(t.scale.shape == std::vector<size_t>{ 1 });
  } else {
    NVTE_CHECK(t.scale.dptr == nullptr,
               "Scale is not supported for non-FP8 output " + name + ".");
    NVTE_CHECK(t.amax.dptr == nullptr,
               "Amax is not supported for non-FP8 output " + name + ".");
    NVTE_CHECK(t.scale_inv.dptr == nullptr,
               "Scale_inv is not supported for non-FP8 output " + name + ".");
  }
  if (!allow_empty) {
    NVTE_CHECK(t.data.dptr != nullptr,
               "Output " + name + " is not allocated!");
  }
}
} // namespace transformer_engine
/*! Creates a TE tensor wrapping the given data buffer together with the
 *  optional FP8 scaling metadata pointers (amax, scale, scale_inv), which
 *  are stored as-is and may be nullptr for non-FP8 tensors.
 *  Fix: the body contained both the pre-refactor assignments (ret->dptr,
 *  ret->shape, ret->dtype) and the new ones (ret->data.*) — stale merge
 *  residue that does not compile against the new Tensor struct; only the
 *  data.* assignments are kept.
 */
NVTETensor nvte_create_tensor(void *dptr,
                              const NVTEShape shape,
                              const NVTEDType dtype,
                              float *amax,
                              float *scale,
                              float *scale_inv) {
  transformer_engine::Tensor *ret = new transformer_engine::Tensor;
  ret->data.dptr = dptr;
  ret->data.shape = std::vector<size_t>(shape.data, shape.data + shape.ndim);
  ret->data.dtype = static_cast<transformer_engine::DType>(dtype);
  ret->amax.dptr = amax;
  ret->scale.dptr = scale;
  ret->scale_inv.dptr = scale_inv;
  return ret;
}
......@@ -34,18 +96,40 @@ void nvte_destroy_tensor(NVTETensor tensor) {
}
/*! Returns the data type of the tensor's main data buffer.
 *  Fix: removed the stale pre-refactor return statement (t->dtype) that was
 *  left alongside the new one (t->data.dtype) — duplicate merge residue.
 */
NVTEDType nvte_tensor_type(const NVTETensor tensor) {
  return static_cast<NVTEDType>(
      reinterpret_cast<const transformer_engine::Tensor*>(tensor)->data.dtype);
}
/*! Returns a non-owning view of the shape of the tensor's main data buffer.
 *  The returned pointers stay valid only as long as the tensor exists.
 *  Fix: removed the stale pre-refactor assignments (t.shape.*) that were
 *  left alongside the new ones (t.data.shape.*) — duplicate merge residue.
 */
NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
  const auto &t = *reinterpret_cast<const transformer_engine::Tensor*>(tensor);
  NVTEShape ret;
  ret.data = t.data.shape.data();
  ret.ndim = t.data.shape.size();
  return ret;
}
/*! Returns the raw device pointer of the tensor's main data buffer.
 *  Fix: removed the stale pre-refactor return statement (t.dptr) that was
 *  left alongside the new one (t.data.dptr) — duplicate merge residue.
 */
void *nvte_tensor_data(const NVTETensor tensor) {
  const auto &t = *reinterpret_cast<const transformer_engine::Tensor*>(tensor);
  return t.data.dptr;
}
/*! Returns the amax buffer of the tensor after verifying it is declared as
 *  FP32; the stored pointer is returned as-is (it may be nullptr).
 */
float *nvte_tensor_amax(const NVTETensor tensor) {
  const auto *te_tensor = reinterpret_cast<const transformer_engine::Tensor*>(tensor);
  NVTE_CHECK(te_tensor->amax.dtype == transformer_engine::DType::kFloat32,
             "Tensor's amax must have Float32 type!");
  return reinterpret_cast<float*>(te_tensor->amax.dptr);
}
/*! Returns the scale buffer of the tensor after verifying it is declared as
 *  FP32; the stored pointer is returned as-is (it may be nullptr).
 */
float *nvte_tensor_scale(const NVTETensor tensor) {
  const auto *te_tensor = reinterpret_cast<const transformer_engine::Tensor*>(tensor);
  NVTE_CHECK(te_tensor->scale.dtype == transformer_engine::DType::kFloat32,
             "Tensor's scale must have Float32 type!");
  return reinterpret_cast<float*>(te_tensor->scale.dptr);
}
/*! Returns the scale_inv buffer of the tensor after verifying it is declared
 *  as FP32; the stored pointer is returned as-is (it may be nullptr).
 */
float *nvte_tensor_scale_inv(const NVTETensor tensor) {
  const auto *te_tensor = reinterpret_cast<const transformer_engine::Tensor*>(tensor);
  NVTE_CHECK(te_tensor->scale_inv.dtype == transformer_engine::DType::kFloat32,
             "Tensor's inverse of scale must have Float32 type!");
  return reinterpret_cast<float*>(te_tensor->scale_inv.dptr);
}
......@@ -105,7 +105,7 @@ cast_transpose_kernel(const IType * const input,
warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
CType max = 0;
const CType scale = *scale_ptr;
const CType scale = scale_ptr != nullptr ? *scale_ptr : 1;
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
......@@ -158,8 +158,8 @@ cast_transpose_kernel(const IType * const input,
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
atomicMaxFloat(amax, max);
reciprocal<float>(scale_inv, scale);
if (amax != nullptr) atomicMaxFloat(amax, max);
if (scale_inv != nullptr) reciprocal<float>(scale_inv, scale);
}
}
......@@ -222,7 +222,7 @@ cast_transpose_kernel_notaligned(const IType * const input,
warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
CType max = 0;
const CType scale = *scale_ptr;
const CType scale = scale_ptr != nullptr ? *scale_ptr : 1;
{
const bool valid_load = my_place < tile_length &&
warp_id_in_tile * n_iterations < tile_height;
......@@ -294,48 +294,41 @@ cast_transpose_kernel_notaligned(const IType * const input,
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
atomicMaxFloat(amax, max);
reciprocal<float>(scale_inv, scale);
if (amax != nullptr) atomicMaxFloat(amax, max);
if (scale_inv != nullptr) reciprocal<float>(scale_inv, scale);
}
}
void cast_transpose(const Tensor &input,
const Tensor &scale,
Tensor *cast_output,
Tensor *transposed_output,
Tensor *amax,
Tensor *scale_inv,
cudaStream_t stream) {
NVTE_CHECK(input.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(cast_output->shape.size() == 2, "C output must have 2 dimensions.");
NVTE_CHECK(transposed_output->shape.size() == 2, "T output must have 2 dimensions.");
NVTE_CHECK(input.shape == cast_output->shape, "Input and C output must have the same shape.");
const size_t row_length = input.shape[1];
const size_t num_rows = input.shape[0];
NVTE_CHECK(transposed_output->shape[0] == row_length, "Wrong dimension of T output.");
NVTE_CHECK(transposed_output->shape[1] == num_rows, "Wrong dimension of T output.");
NVTE_CHECK(cast_output->dtype == transposed_output->dtype,
"Both C and T outputs need to have the same type.");
NVTE_CHECK(amax->shape == std::vector<size_t>{ 1 }, "AMAX tensor must have 1 element.");
NVTE_CHECK(amax->dtype == DType::kFloat32, "AMAX tensor must have Float32 type.");
NVTE_CHECK(scale_inv->shape == std::vector<size_t>{ 1 },
"scale_inv tensor must have 1 element.");
NVTE_CHECK(scale_inv->dtype == DType::kFloat32, "scale_inv tensor must have Float32 type.");
NVTE_CHECK(scale.shape == std::vector<size_t>{ 1 }, "Scale tensor must have 1 element.");
NVTE_CHECK(scale.dtype == DType::kFloat32, "Scale tensor must have Float32 type.");
NVTE_CHECK(input.dptr != nullptr, "Input is not allocated.");
NVTE_CHECK(scale.dptr != nullptr, "Scale is not allocated.");
NVTE_CHECK(transposed_output->dptr != nullptr, "T output is not allocated.");
NVTE_CHECK(cast_output->dptr != nullptr, "C output is not allocated.");
NVTE_CHECK(amax->dptr != nullptr, "AMAX output is not allocated.");
NVTE_CHECK(scale_inv->dptr != nullptr, "scale_inv output is not allocated.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.dtype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->dtype, OutputType,
CheckInputTensor(input, "cast_transpose_input");
CheckOutputTensor(*cast_output, "cast_output");
CheckOutputTensor(*transposed_output, "transposed_output");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions.");
NVTE_CHECK(input.data.shape == cast_output->data.shape,
"Input and C output must have the same shape.");
const size_t row_length = input.data.shape[1];
const size_t num_rows = input.data.shape[0];
NVTE_CHECK(transposed_output->data.shape[0] == row_length, "Wrong dimension of T output.");
NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output.");
NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype,
"C and T outputs need to have the same type.");
NVTE_CHECK(cast_output->amax.dptr == transposed_output->amax.dptr,
"C and T outputs need to share amax tensor.");
NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr,
"C and T outputs need to share scale tensor.");
NVTE_CHECK(cast_output->scale_inv.dptr == transposed_output->scale_inv.dptr,
"C and T outputs need to share scale inverse tensor.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->data.dtype, OutputType,
constexpr int itype_size = sizeof(InputType);
constexpr int otype_size = sizeof(OutputType);
constexpr int nvec_in = desired_load_size / itype_size;
......@@ -363,12 +356,12 @@ void cast_transpose(const Tensor &input,
cast_transpose_num_threads / n_warps_per_tile *
(THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>),
stream>>>(
reinterpret_cast<const InputType *>(input.dptr),
reinterpret_cast<OutputType *>(cast_output->dptr),
reinterpret_cast<OutputType *>(transposed_output->dptr),
reinterpret_cast<const fp32 *>(scale.dptr),
reinterpret_cast<fp32 *>(amax->dptr),
reinterpret_cast<fp32 *>(scale_inv->dptr),
reinterpret_cast<const InputType *>(input.data.dptr),
reinterpret_cast<OutputType *>(cast_output->data.dptr),
reinterpret_cast<OutputType *>(transposed_output->data.dptr),
reinterpret_cast<const fp32 *>(cast_output->scale.dptr),
reinterpret_cast<fp32 *>(cast_output->amax.dptr),
reinterpret_cast<fp32 *>(cast_output->scale_inv.dptr),
row_length, num_rows, n_tiles);
} else {
cudaFuncSetAttribute(cast_transpose_kernel_notaligned<nvec_in, nvec_out, fp32,
......@@ -381,12 +374,12 @@ void cast_transpose(const Tensor &input,
cast_transpose_num_threads / n_warps_per_tile *
(THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>),
stream>>>(
reinterpret_cast<const InputType *>(input.dptr),
reinterpret_cast<OutputType *>(cast_output->dptr),
reinterpret_cast<OutputType *>(transposed_output->dptr),
reinterpret_cast<const fp32 *>(scale.dptr),
reinterpret_cast<fp32 *>(amax->dptr),
reinterpret_cast<fp32 *>(scale_inv->dptr),
reinterpret_cast<const InputType *>(input.data.dptr),
reinterpret_cast<OutputType *>(cast_output->data.dptr),
reinterpret_cast<OutputType *>(transposed_output->data.dptr),
reinterpret_cast<const fp32 *>(cast_output->scale.dptr),
reinterpret_cast<fp32 *>(cast_output->amax.dptr),
reinterpret_cast<fp32 *>(cast_output->scale_inv.dptr),
row_length, num_rows, n_tiles);
}
); // NOLINT(*)
......@@ -396,18 +389,12 @@ void cast_transpose(const Tensor &input,
} // namespace transformer_engine
/*! C API entry point: casts `input` into `cast_output` and its transpose
 *  into `transposed_output` on `stream`. Per the updated API the scaling
 *  metadata (scale, amax, scale_inv) now lives inside the output tensors.
 *  Fix: removed the stale pre-refactor parameters (scale, amax, scale_inv)
 *  and their corresponding casts, which were left interleaved with the new
 *  signature — merge residue that does not compile.
 */
void nvte_cast_transpose(const NVTETensor input,
                         NVTETensor cast_output,
                         NVTETensor transposed_output,
                         cudaStream_t stream) {
  using namespace transformer_engine;
  cast_transpose(*reinterpret_cast<const Tensor*>(input),
                 reinterpret_cast<Tensor*>(cast_output),
                 reinterpret_cast<Tensor*>(transposed_output),
                 stream);
}
......@@ -162,7 +162,7 @@ cast_transpose_dbias_kernel(const Param param,
warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
CType max = 0;
const CType scale = *param.scale_ptr;
const CType scale = param.scale_ptr != nullptr ? *param.scale_ptr : 1;
partial_dbias.clear();
......@@ -240,8 +240,8 @@ cast_transpose_dbias_kernel(const Param param,
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
atomicMaxFloat(param.amax, max);
reciprocal<CType>(param.scale_inv, scale);
if (param.amax != nullptr) atomicMaxFloat(param.amax, max);
if (param.scale_inv != nullptr) reciprocal<CType>(param.scale_inv, scale);
}
}
......@@ -310,7 +310,7 @@ cast_transpose_dbias_kernel_notaligned(const Param param,
warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
CType max = 0;
const CType scale = *param.scale_ptr;
const CType scale = param.scale_ptr != nullptr ? *param.scale_ptr : 1;
partial_dbias.clear();
......@@ -409,8 +409,8 @@ cast_transpose_dbias_kernel_notaligned(const Param param,
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
atomicMaxFloat(param.amax, max);
reciprocal<CType>(param.scale_inv, scale);
if (param.amax != nullptr) atomicMaxFloat(param.amax, max);
if (param.scale_inv != nullptr) reciprocal<CType>(param.scale_inv, scale);
}
}
......@@ -456,16 +456,16 @@ reduce_dbias_kernel(OutputType* const dbias_output,
// Fills in the shape/dtype of the partial-dbias workspace tensor so the
// caller can allocate it before the real kernel launch. The workspace holds
// one fp32 partial-reduction row per tile of `tile_size_y` input rows.
// NOTE(review): this span is diff residue — the pre-refactor lines
// (`cast_output.shape`, `workspace->shape`) and post-refactor lines
// (`cast_output.data.shape`, `workspace->data.shape`) are BOTH present,
// which redeclares `row_length`/`num_rows` and will not compile. Only one
// variant should remain; confirm against the upstream commit.
void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output, /*cast*/
Tensor* workspace,
const int nvec_out) {
const size_t row_length = cast_output.shape[1];
const size_t num_rows = cast_output.shape[0];
const size_t row_length = cast_output.data.shape[1];
const size_t num_rows = cast_output.data.shape[0];
// Each warp-tile reduces nvec_out * THREADS_PER_WARP rows into one
// partial-dbias row.
const size_t tile_size_y = (nvec_out * THREADS_PER_WARP);
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
const size_t num_rows_partial_dbias = DIVUP(num_rows, tile_size_y);
// Partial reductions are accumulated in fp32 regardless of the I/O dtypes.
workspace->shape = {num_rows_partial_dbias, row_length};
workspace->dtype = DType::kFloat32;
workspace->data.shape = {num_rows_partial_dbias, row_length};
workspace->data.dtype = DType::kFloat32;
}
template <typename InputType>
......@@ -489,61 +489,54 @@ void reduce_dbias(const Tensor &workspace, Tensor *dbias,
reduce_dbias_num_threads,
0,
stream>>>(
reinterpret_cast<InputType *>(dbias->dptr),
reinterpret_cast<const fp32 *>(workspace.dptr),
reinterpret_cast<InputType *>(dbias->data.dptr),
reinterpret_cast<const fp32 *>(workspace.data.dptr),
reduce_dbias_row_length,
reduce_dbias_num_rows);
}
void cast_transpose_dbias(const Tensor &input,
const Tensor &scale,
Tensor *cast_output,
Tensor *transposed_output,
Tensor *amax,
Tensor *dbias,
Tensor *scale_inv,
Tensor *workspace,
cudaStream_t stream) {
NVTE_CHECK(input.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(cast_output->shape.size() == 2, "C output must have 2 dimensions.");
NVTE_CHECK(transposed_output->shape.size() == 2, "T output must have 2 dimensions.");
NVTE_CHECK(input.shape == cast_output->shape, "Input and C output must have the same shape.");
const size_t row_length = input.shape[1];
const size_t num_rows = input.shape[0];
NVTE_CHECK(transposed_output->shape[0] == row_length, "Wrong dimension of T output.");
NVTE_CHECK(transposed_output->shape[1] == num_rows, "Wrong dimension of T output.");
NVTE_CHECK(cast_output->dtype == transposed_output->dtype,
"Both T and C outputs need to have the same type.");
NVTE_CHECK(amax->shape == std::vector<size_t>{ 1 }, "AMAX tensor must have 1 element.");
NVTE_CHECK(amax->dtype == DType::kFloat32, "AMAX tensor must have Float32 type.");
NVTE_CHECK(scale_inv->shape == std::vector<size_t>{ 1 },
"scale_inv tensor must have 1 element.");
NVTE_CHECK(scale_inv->dtype == DType::kFloat32, "scale_inv tensor must have Float32 type.");
NVTE_CHECK(scale.shape == std::vector<size_t>{ 1 }, "Scale tensor must have 1 element.");
NVTE_CHECK(scale.dtype == DType::kFloat32, "Scale tensor must have Float32 type.");
NVTE_CHECK(input.dptr != nullptr, "Input is not allocated.");
NVTE_CHECK(scale.dptr != nullptr, "Scale is not allocated.");
NVTE_CHECK(transposed_output->dptr != nullptr, "T output is not allocated.");
NVTE_CHECK(cast_output->dptr != nullptr, "C output is not allocated.");
NVTE_CHECK(amax->dptr != nullptr, "AMAX output is not allocated.");
NVTE_CHECK(scale_inv->dptr != nullptr, "scale_inv output is not allocated.");
NVTE_CHECK(dbias->dptr != nullptr, "DBias is not allocated.");
NVTE_CHECK(dbias->dtype == input.dtype, "DBias must have the same type as input.");
NVTE_CHECK(dbias->shape == std::vector<size_t>{ row_length }, "Wrong shape of DBias.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.dtype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->dtype, OutputType,
CheckInputTensor(input, "cast_transpose_dbias_input");
CheckOutputTensor(*cast_output, "cast_output");
CheckOutputTensor(*transposed_output, "transposed_output");
CheckOutputTensor(*dbias, "dbias");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions.");
NVTE_CHECK(input.data.shape == cast_output->data.shape,
"Input and C output must have the same shape.");
const size_t row_length = input.data.shape[1];
const size_t num_rows = input.data.shape[0];
NVTE_CHECK(transposed_output->data.shape[0] == row_length, "Wrong dimension of T output.");
NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output.");
NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype,
"C and T outputs need to have the same type.");
NVTE_CHECK(cast_output->amax.dptr == transposed_output->amax.dptr,
"C and T outputs need to share amax tensor.");
NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr,
"C and T outputs need to share scale tensor.");
NVTE_CHECK(cast_output->scale_inv.dptr == transposed_output->scale_inv.dptr,
"C and T outputs need to share scale inverse tensor.");
NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input.");
NVTE_CHECK(dbias->data.shape == std::vector<size_t>{ row_length }, "Wrong shape of DBias.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->data.dtype, OutputType,
constexpr int itype_size = sizeof(InputType);
constexpr int otype_size = sizeof(OutputType);
constexpr int nvec_in = desired_load_size / itype_size;
constexpr int nvec_out = desired_store_size / otype_size;
if (workspace->dptr == nullptr) {
if (workspace->data.dptr == nullptr) {
populate_cast_transpose_dbias_workspace_config(*cast_output, workspace, nvec_out);
return;
}
......@@ -567,13 +560,13 @@ void cast_transpose_dbias(const Tensor &input,
static_assert(shared_size_transpose >= shared_size_dbias);
using Param = CTDBiasParam<InputType, OutputType, ComputeType>;
Param param;
param.input = reinterpret_cast<const InputType *>(input.dptr);
param.output_c = reinterpret_cast<OutputType *>(cast_output->dptr);
param.output_t = reinterpret_cast<OutputType *>(transposed_output->dptr);
param.scale_ptr = reinterpret_cast<const ComputeType *>(scale.dptr);
param.amax = reinterpret_cast<ComputeType *>(amax->dptr);
param.scale_inv = reinterpret_cast<ComputeType *>(scale_inv->dptr);
param.workspace = reinterpret_cast<ComputeType *>(workspace->dptr);
param.input = reinterpret_cast<const InputType *>(input.data.dptr);
param.output_c = reinterpret_cast<OutputType *>(cast_output->data.dptr);
param.output_t = reinterpret_cast<OutputType *>(transposed_output->data.dptr);
param.scale_ptr = reinterpret_cast<const ComputeType *>(cast_output->scale.dptr);
param.amax = reinterpret_cast<ComputeType *>(cast_output->amax.dptr);
param.scale_inv = reinterpret_cast<ComputeType *>(cast_output->scale_inv.dptr);
param.workspace = reinterpret_cast<ComputeType *>(workspace->data.dptr);
if (full_tile) {
cudaFuncSetAttribute(cast_transpose_dbias_kernel<nvec_in, nvec_out, Param>,
......@@ -678,7 +671,7 @@ cast_transpose_dbias_dgelu_kernel(const Param param,
warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
CType max = 0;
const CType scale = *param.scale_ptr;
const CType scale = param.scale_ptr != nullptr ? *param.scale_ptr : 1;
partial_dbias.clear();
......@@ -769,8 +762,8 @@ cast_transpose_dbias_dgelu_kernel(const Param param,
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
atomicMaxFloat(param.amax, max);
reciprocal<CType>(param.scale_inv, scale);
if (param.amax != nullptr) atomicMaxFloat(param.amax, max);
if (param.scale_inv != nullptr) reciprocal<CType>(param.scale_inv, scale);
}
}
......@@ -846,7 +839,7 @@ cast_transpose_dbias_dgelu_kernel_notaligned(const Param param,
warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
CType max = 0;
const CType scale = *param.scale_ptr;
const CType scale = param.scale_ptr != nullptr ? *param.scale_ptr : 1;
partial_dbias.clear();
......@@ -960,61 +953,53 @@ cast_transpose_dbias_dgelu_kernel_notaligned(const Param param,
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
atomicMaxFloat(param.amax, max);
reciprocal<CType>(param.scale_inv, scale);
if (param.amax != nullptr) atomicMaxFloat(param.amax, max);
if (param.scale_inv != nullptr) reciprocal<CType>(param.scale_inv, scale);
}
}
void cast_transpose_dbias_dgelu(const Tensor &input,
const Tensor &gelu_input,
const Tensor &scale,
Tensor *cast_output,
Tensor *transposed_output,
Tensor *amax,
Tensor *dbias,
Tensor *scale_inv,
Tensor *workspace,
cudaStream_t stream) {
NVTE_CHECK(input.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(cast_output->shape.size() == 2, "C output must have 2 dimensions.");
NVTE_CHECK(transposed_output->shape.size() == 2,
CheckInputTensor(input, "cast_transpose_dbias_dgelu_input");
CheckInputTensor(gelu_input, "gelu_input");
CheckOutputTensor(*cast_output, "cast_output");
CheckOutputTensor(*transposed_output, "transposed_output");
CheckOutputTensor(*dbias, "dbias");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
NVTE_CHECK(transposed_output->data.shape.size() == 2,
"T output must have 2 dimensions.");
NVTE_CHECK(input.shape == cast_output->shape,
NVTE_CHECK(input.data.shape == cast_output->data.shape,
"Input and C output must have the same shape.");
const size_t row_length = input.shape[1];
const size_t num_rows = input.shape[0];
NVTE_CHECK(transposed_output->shape[0] == row_length, "Wrong dimension of T output.");
NVTE_CHECK(transposed_output->shape[1] == num_rows, "Wrong dimension of T output.");
NVTE_CHECK(cast_output->dtype == transposed_output->dtype,
"Both C and T outputs need to have the same type.");
NVTE_CHECK(amax->shape == std::vector<size_t>{ 1 }, "AMAX tensor must have 1 element.");
NVTE_CHECK(amax->dtype == DType::kFloat32, "AMAX tensor must have Float32 type.");
NVTE_CHECK(scale_inv->shape == std::vector<size_t>{ 1 },
"scale_inv tensor must have 1 element.");
NVTE_CHECK(scale_inv->dtype == DType::kFloat32, "scale_inv tensor must have Float32 type.");
NVTE_CHECK(scale.shape == std::vector<size_t>{ 1 }, "Scale tensor must have 1 element.");
NVTE_CHECK(scale.dtype == DType::kFloat32, "Scale tensor must have Float32 type.");
NVTE_CHECK(input.dptr != nullptr, "Input is not allocated.");
NVTE_CHECK(gelu_input.dptr != nullptr, "GeLU input is not allocated.");
NVTE_CHECK(scale.dptr != nullptr, "Scale is not allocated.");
NVTE_CHECK(transposed_output->dptr != nullptr, "T output is not allocated.");
NVTE_CHECK(cast_output->dptr != nullptr, "C output is not allocated.");
NVTE_CHECK(amax->dptr != nullptr, "AMAX output is not allocated.");
NVTE_CHECK(scale_inv->dptr != nullptr, "scale_inv output is not allocated.");
NVTE_CHECK(dbias->dptr != nullptr, "DBias is not allocated.");
NVTE_CHECK(dbias->dtype == input.dtype, "DBias must have the same type as input.");
NVTE_CHECK(dbias->shape == std::vector<size_t>{ row_length }, "Wrong shape of DBias.");
NVTE_CHECK(input.dtype == gelu_input.dtype, "Types of both inputs must match.");
NVTE_CHECK(input.shape == gelu_input.shape, "Shapes of both inputs must match.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.dtype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->dtype, OutputType,
const size_t row_length = input.data.shape[1];
const size_t num_rows = input.data.shape[0];
NVTE_CHECK(transposed_output->data.shape[0] == row_length, "Wrong dimension of T output.");
NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output.");
NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype,
"C and T outputs need to have the same type.");
NVTE_CHECK(cast_output->amax.dptr == transposed_output->amax.dptr,
"C and T outputs need to share amax tensor.");
NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr,
"C and T outputs need to share scale tensor.");
NVTE_CHECK(cast_output->scale_inv.dptr == transposed_output->scale_inv.dptr,
"C and T outputs need to share scale inverse tensor.");
NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input.");
NVTE_CHECK(dbias->data.shape == std::vector<size_t>{ row_length }, "Wrong shape of DBias.");
NVTE_CHECK(input.data.dtype == gelu_input.data.dtype, "Types of both inputs must match.");
NVTE_CHECK(input.data.shape == gelu_input.data.shape, "Shapes of both inputs must match.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->data.dtype, OutputType,
using InputType2 = InputType;
/* dgelu fusion kernel uses more registers */
constexpr int desired_load_size_dgelu = 4;
......@@ -1024,7 +1009,7 @@ void cast_transpose_dbias_dgelu(const Tensor &input,
constexpr int nvec_in = desired_load_size_dgelu / itype_size;
constexpr int nvec_out = desired_store_size_dgelu / otype_size;
if (workspace->dptr == nullptr) {
if (workspace->data.dptr == nullptr) {
populate_cast_transpose_dbias_workspace_config(*cast_output, workspace, nvec_out);
return;
}
......@@ -1048,14 +1033,14 @@ void cast_transpose_dbias_dgelu(const Tensor &input,
static_assert(shared_size_transpose >= shared_size_dbias);
using Param = CTDBiasDGeluParam<InputType, InputType2, OutputType, ComputeType>;
Param param;
param.input = reinterpret_cast<const InputType *>(input.dptr);
param.gelu_input = reinterpret_cast<const InputType2 *>(gelu_input.dptr);
param.output_c = reinterpret_cast<OutputType *>(cast_output->dptr);
param.output_t = reinterpret_cast<OutputType *>(transposed_output->dptr);
param.scale_ptr = reinterpret_cast<const ComputeType *>(scale.dptr);
param.amax = reinterpret_cast<ComputeType *>(amax->dptr);
param.scale_inv = reinterpret_cast<ComputeType *>(scale_inv->dptr);
param.workspace = reinterpret_cast<ComputeType *>(workspace->dptr);
param.input = reinterpret_cast<const InputType *>(input.data.dptr);
param.gelu_input = reinterpret_cast<const InputType2 *>(gelu_input.data.dptr);
param.output_c = reinterpret_cast<OutputType *>(cast_output->data.dptr);
param.output_t = reinterpret_cast<OutputType *>(transposed_output->data.dptr);
param.scale_ptr = reinterpret_cast<const ComputeType *>(cast_output->scale.dptr);
param.amax = reinterpret_cast<ComputeType *>(cast_output->amax.dptr);
param.scale_inv = reinterpret_cast<ComputeType *>(cast_output->scale_inv.dptr);
param.workspace = reinterpret_cast<ComputeType *>(workspace->data.dptr);
if (full_tile) {
cudaFuncSetAttribute(cast_transpose_dbias_dgelu_kernel<nvec_in, nvec_out, Param>,
cudaFuncAttributePreferredSharedMemoryCarveout,
......@@ -1084,45 +1069,33 @@ void cast_transpose_dbias_dgelu(const Tensor &input,
} // namespace transformer_engine
/*! \brief C-API entry point for fused cast + transpose + dbias reduction.
 *
 *  Converts the opaque handles to internal Tensor pointers and delegates to
 *  cast_transpose_dbias(). When the workspace tensor is unallocated, the
 *  callee only populates its required configuration (see implementation).
 */
void nvte_cast_transpose_dbias(const NVTETensor input,
                               const NVTETensor scale,
                               NVTETensor cast_output,
                               NVTETensor transposed_output,
                               NVTETensor amax,
                               NVTETensor dbias,
                               NVTETensor scale_inv,
                               NVTETensor workspace,
                               cudaStream_t stream) {
  using namespace transformer_engine;
  const auto &input_tensor = *reinterpret_cast<const Tensor *>(input);
  const auto &scale_tensor = *reinterpret_cast<const Tensor *>(scale);
  cast_transpose_dbias(input_tensor,
                       scale_tensor,
                       reinterpret_cast<Tensor *>(cast_output),
                       reinterpret_cast<Tensor *>(transposed_output),
                       reinterpret_cast<Tensor *>(amax),
                       reinterpret_cast<Tensor *>(dbias),
                       reinterpret_cast<Tensor *>(scale_inv),
                       reinterpret_cast<Tensor *>(workspace),
                       stream);
}
/*! \brief C-API entry point for fused dGeLU + cast + transpose + dbias.
 *
 *  Pure handle-conversion shim around cast_transpose_dbias_dgelu(); all
 *  shape/type validation is performed by the callee.
 */
void nvte_cast_transpose_dbias_dgelu(const NVTETensor input,
                                     const NVTETensor gelu_input,
                                     const NVTETensor scale,
                                     NVTETensor cast_output,
                                     NVTETensor transposed_output,
                                     NVTETensor amax,
                                     NVTETensor dbias,
                                     NVTETensor scale_inv,
                                     NVTETensor workspace,
                                     cudaStream_t stream) {
  using namespace transformer_engine;
  const auto &input_tensor = *reinterpret_cast<const Tensor *>(input);
  const auto &gelu_input_tensor = *reinterpret_cast<const Tensor *>(gelu_input);
  const auto &scale_tensor = *reinterpret_cast<const Tensor *>(scale);
  cast_transpose_dbias_dgelu(input_tensor,
                             gelu_input_tensor,
                             scale_tensor,
                             reinterpret_cast<Tensor *>(cast_output),
                             reinterpret_cast<Tensor *>(transposed_output),
                             reinterpret_cast<Tensor *>(amax),
                             reinterpret_cast<Tensor *>(dbias),
                             reinterpret_cast<Tensor *>(scale_inv),
                             reinterpret_cast<Tensor *>(workspace),
                             stream);
}
......@@ -87,7 +87,8 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
const IType* input = reinterpret_cast<const IType*>(args.input_list[tensor_id]);
OType* output_c = reinterpret_cast<OType*>(args.output_c_list[tensor_id]);
OType* output_t = reinterpret_cast<OType*>(args.output_t_list[tensor_id]);
const CType scale = *reinterpret_cast<CType*>(args.scale_list[tensor_id]);
const CType* scale_ptr = reinterpret_cast<CType*>(args.scale_list[tensor_id]);
const CType scale = scale_ptr == nullptr ? 1 : *scale_ptr;
CType* amax = reinterpret_cast<CType*>(args.amax_list[tensor_id]);
CType* scale_inv = reinterpret_cast<CType*>(args.scale_inv_list[tensor_id]);
const int num_rows = args.num_rows_list[tensor_id];
......@@ -190,79 +191,56 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
local_amax = reduce_max<n_warps_per_tile>(local_amax, tidy);
if (tid == 0) {
static_assert(std::is_same<CType, float>::value);
atomicMaxFloat(amax, local_amax);
if (amax != nullptr) atomicMaxFloat(amax, local_amax);
}
if (tid == 0 && tile_id == 0) {
reciprocal<float>(scale_inv, scale);
if (scale_inv != nullptr) reciprocal<float>(scale_inv, scale);
}
}
} // namespace
void multi_cast_transpose(const std::vector<Tensor*> input_list,
const std::vector<Tensor*> scale_list,
std::vector<Tensor*> cast_output_list,
std::vector<Tensor*> transposed_output_list,
std::vector<Tensor*> amax_list,
std::vector<Tensor*> scale_inv_list,
cudaStream_t stream) {
// Check that number of tensors is valid
NVTE_CHECK(scale_list.size() == input_list.size(),
"Number of input and scale tensors must match");
NVTE_CHECK(cast_output_list.size() == input_list.size(),
"Number of input and C output tensors must match");
NVTE_CHECK(transposed_output_list.size() == input_list.size(),
"Number of input and T output tensors must match");
NVTE_CHECK(amax_list.size() == input_list.size(),
"Number of input and AMAX tensors must match");
NVTE_CHECK(scale_inv_list.size() == input_list.size(),
"Number of input and scale_inv tensors must match");
if (input_list.empty()) {
return;
}
// Check that tensor properties are valid
DType ctype = DType::kFloat32;
DType itype = input_list[0]->dtype;
DType otype = cast_output_list[0]->dtype;
DType itype = input_list[0]->data.dtype;
DType otype = cast_output_list[0]->data.dtype;
for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
const auto& input = *input_list[tensor_id];
const auto& scale = *scale_list[tensor_id];
const auto& cast_output = *cast_output_list[tensor_id];
const auto& transposed_output = *transposed_output_list[tensor_id];
const auto& amax = *amax_list[tensor_id];
const auto& scale_inv = *scale_inv_list[tensor_id];
CheckInputTensor(input, "multi_cast_transpose_input_" + std::to_string(tensor_id));
CheckInputTensor(cast_output, "multi_cast_output_" + std::to_string(tensor_id));
CheckInputTensor(transposed_output, "multi_transpose_output_" + std::to_string(tensor_id));
NVTE_CHECK(input.dtype == itype,
NVTE_CHECK(input.data.dtype == itype,
"Input tensor types do not match.");
NVTE_CHECK(scale.dtype == ctype,
"Scale tensor must have Float32 type.");
NVTE_CHECK(cast_output.dtype == otype,
NVTE_CHECK(cast_output.data.dtype == otype,
"C output tensor types do not match.");
NVTE_CHECK(transposed_output.dtype == otype,
NVTE_CHECK(transposed_output.data.dtype == otype,
"T output tensor types do not match.");
NVTE_CHECK(amax.dtype == ctype,
"AMAX tensor must have Float32 type.");
NVTE_CHECK(scale_inv.dtype == ctype,
"scale_inv tensor must have Float32 type.");
NVTE_CHECK(input.shape.size() == 2,
NVTE_CHECK(input.data.shape.size() == 2,
"Input tensor must have 2 dimensions.");
NVTE_CHECK(cast_output.shape == input.shape,
NVTE_CHECK(cast_output.data.shape == input.data.shape,
"C output tensor shape does not match input tensor.");
NVTE_CHECK(transposed_output.shape.size() == 2,
NVTE_CHECK(transposed_output.data.shape.size() == 2,
"T output tensor shape does not match input tensor.");
NVTE_CHECK(transposed_output.shape[0] == input.shape[1],
NVTE_CHECK(transposed_output.data.shape[0] == input.data.shape[1],
"T output tensor shape does not match input tensor.");
NVTE_CHECK(transposed_output.shape[1] == input.shape[0],
NVTE_CHECK(transposed_output.data.shape[1] == input.data.shape[0],
"T output tensor shape does not match input tensor.");
NVTE_CHECK(input.dptr != nullptr, "Input is not allocated.");
NVTE_CHECK(scale.dptr != nullptr, "Scale is not allocated.");
NVTE_CHECK(cast_output.dptr != nullptr, "C output is not allocated.");
NVTE_CHECK(transposed_output.dptr != nullptr, "T output is not allocated.");
NVTE_CHECK(amax.dptr != nullptr, "AMAX output is not allocated.");
NVTE_CHECK(scale_inv.dptr != nullptr, "scale_inv output is not allocated.");
}
// Input matrices are divided into tiles
......@@ -304,8 +282,8 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list,
}
// Calculate number of thread blocks needed for tensor
const int num_rows = input_list[tensor_id]->shape[0];
const int row_length = input_list[tensor_id]->shape[1];
const int num_rows = input_list[tensor_id]->data.shape[0];
const int row_length = input_list[tensor_id]->data.shape[1];
const int num_tiles_m = (num_rows + tile_dim_m - 1) / tile_dim_m;
const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n;
const int num_tiles = num_tiles_m * num_tiles_n;
......@@ -317,12 +295,12 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list,
// Add tensor to kernel argument struct
const int pos = kernel_args.num_tensors;
kernel_args.input_list[pos] = const_cast<void*>(input_list[tensor_id]->dptr);
kernel_args.output_c_list[pos] = cast_output_list[tensor_id]->dptr;
kernel_args.output_t_list[pos] = transposed_output_list[tensor_id]->dptr;
kernel_args.scale_list[pos] = const_cast<void*>(scale_list[tensor_id]->dptr);
kernel_args.amax_list[pos] = amax_list[tensor_id]->dptr;
kernel_args.scale_inv_list[pos] = scale_inv_list[tensor_id]->dptr;
kernel_args.input_list[pos] = const_cast<void*>(input_list[tensor_id]->data.dptr);
kernel_args.output_c_list[pos] = cast_output_list[tensor_id]->data.dptr;
kernel_args.output_t_list[pos] = transposed_output_list[tensor_id]->data.dptr;
kernel_args.scale_list[pos] = cast_output_list[tensor_id]->scale.dptr;
kernel_args.amax_list[pos] = cast_output_list[tensor_id]->amax.dptr;
kernel_args.scale_inv_list[pos] = cast_output_list[tensor_id]->scale_inv.dptr;
kernel_args.num_rows_list[pos] = num_rows;
kernel_args.row_length_list[pos] = row_length;
kernel_args.block_range[pos+1] = kernel_args.block_range[pos] + num_tiles;
......@@ -358,28 +336,19 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list,
// C-API entry point for the multi-tensor cast+transpose: converts arrays of
// opaque NVTETensor handles into vectors of internal Tensor pointers and
// forwards them to multi_cast_transpose().
// NOTE(review): this span is diff residue — both the pre-refactor vector
// declaration (with scale_list_/amax_list_/scale_inv_list_) and the
// post-refactor one are present, redeclaring input_list_ etc.; likewise the
// body mixes push_backs and call arguments from both variants. Only one
// variant should remain; confirm against the upstream commit.
void nvte_multi_cast_transpose(size_t num_tensors,
const NVTETensor* input_list,
const NVTETensor* scale_list,
NVTETensor* cast_output_list,
NVTETensor* transposed_output_list,
NVTETensor* amax_list,
NVTETensor* scale_inv_list,
cudaStream_t stream) {
using namespace transformer_engine;
std::vector<Tensor*> input_list_, scale_list_,
cast_output_list_, transposed_output_list_, amax_list_, scale_inv_list_;
std::vector<Tensor*> input_list_,
cast_output_list_, transposed_output_list_;
for (size_t i = 0; i < num_tensors; ++i) {
// const_cast is needed because the C API exposes inputs as const handles
// while the internal vectors hold mutable Tensor pointers.
input_list_.push_back(reinterpret_cast<Tensor*>(const_cast<NVTETensor&>(input_list[i])));
scale_list_.push_back(reinterpret_cast<Tensor*>(const_cast<NVTETensor&>(scale_list[i])));
cast_output_list_.push_back(reinterpret_cast<Tensor*>(cast_output_list[i]));
transposed_output_list_.push_back(reinterpret_cast<Tensor*>(transposed_output_list[i]));
amax_list_.push_back(reinterpret_cast<Tensor*>(amax_list[i]));
scale_inv_list_.push_back(reinterpret_cast<Tensor*>(scale_inv_list[i]));
}
multi_cast_transpose(input_list_,
scale_list_,
cast_output_list_,
transposed_output_list_,
amax_list_,
scale_inv_list_,
stream);
}
......@@ -244,19 +244,20 @@ transpose_kernel_notaligned(const IType * const input,
void transpose(const Tensor &input,
Tensor *transposed_output,
cudaStream_t stream) {
NVTE_CHECK(input.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(transposed_output->shape.size() == 2, "Output must have 2 dimensions.");
const size_t row_length = input.shape[1];
const size_t num_rows = input.shape[0];
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(transposed_output->data.shape.size() == 2, "Output must have 2 dimensions.");
const size_t row_length = input.data.shape[1];
const size_t num_rows = input.data.shape[0];
NVTE_CHECK(transposed_output->shape[0] == row_length, "Wrong dimension of output.");
NVTE_CHECK(transposed_output->shape[1] == num_rows, "Wrong dimension of output.");
NVTE_CHECK(transposed_output->data.shape[0] == row_length, "Wrong dimension of output.");
NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of output.");
NVTE_CHECK(input.dptr != nullptr, "Input is not allocated.");
NVTE_CHECK(transposed_output->dptr != nullptr, "Output is not allocated.");
NVTE_CHECK(input.dtype == transposed_output->dtype, "Input and output type must match.");
NVTE_CHECK(input.data.dptr != nullptr, "Input is not allocated.");
NVTE_CHECK(transposed_output->data.dptr != nullptr, "Output is not allocated.");
NVTE_CHECK(input.data.dtype == transposed_output->data.dtype,
"Input and output type must match.");
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(input.dtype, Type,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(input.data.dtype, Type,
constexpr int type_size = sizeof(Type);
constexpr int nvec_in = desired_load_size / type_size;
constexpr int nvec_out = desired_store_size / type_size;
......@@ -282,8 +283,8 @@ void transpose(const Tensor &input,
cast_transpose_num_threads / n_warps_per_tile *
(THREADS_PER_WARP + 1) * sizeof(Vec<Type, nvec_out>),
stream>>>(
reinterpret_cast<const Type *>(input.dptr),
reinterpret_cast<Type *>(transposed_output->dptr),
reinterpret_cast<const Type *>(input.data.dptr),
reinterpret_cast<Type *>(transposed_output->data.dptr),
row_length, num_rows, n_tiles);
} else {
cudaFuncSetAttribute(transpose_kernel_notaligned<nvec_in, nvec_out, fp32, Type, Type>,
......@@ -295,8 +296,8 @@ void transpose(const Tensor &input,
cast_transpose_num_threads / n_warps_per_tile *
(THREADS_PER_WARP + 1) * sizeof(Vec<Type, nvec_out>),
stream>>>(
reinterpret_cast<const Type *>(input.dptr),
reinterpret_cast<Type *>(transposed_output->dptr),
reinterpret_cast<const Type *>(input.data.dptr),
reinterpret_cast<Type *>(transposed_output->data.dptr),
row_length, num_rows, n_tiles);
}
); // NOLINT(*)
......
......@@ -30,113 +30,82 @@ __device__ inline fp32 dequantize_func(fp32 value, const DequantizeParam &param)
} // namespace detail
// Quantizes a higher-precision tensor to FP8 (E4M3/E5M2) element-wise via the
// vectorized unary-kernel launcher, recording amax and writing scale_inv.
// NOTE(review): this span is diff residue — the pre-refactor body (explicit
// scale/amax/scale_inv Tensor arguments with manual NVTE_CHECKs) and the
// post-refactor body (metadata pulled from output->scale/amax/scale_inv,
// validated via CheckInput/OutputTensor) are BOTH present back to back, so
// `N` is declared twice and the signature no longer matches the second body.
// Only one variant should remain; confirm against the upstream commit.
void fp8_quantize(const Tensor &input,
const Tensor &scale,
Tensor *output,
Tensor *amax,
Tensor *scale_inv,
cudaStream_t stream) {
NVTE_CHECK(input.dtype != DType::kFloat8E4M3 &&
input.dtype != DType::kFloat8E5M2,
"Input must be in higher precision.");
NVTE_CHECK(input.dptr != nullptr, "Input is not allocated.");
NVTE_CHECK(output->dptr != nullptr, "Output is not allocated.");
NVTE_CHECK(output->dtype == DType::kFloat8E4M3 ||
output->dtype == DType::kFloat8E5M2,
"Output must have FP8 type.");
NVTE_CHECK(output->shape == input.shape, "Input and output shapes need to match.");
NVTE_CHECK(scale.dptr != nullptr, "Scale is not allocated.");
NVTE_CHECK(scale.dtype == DType::kFloat32, "Scale must have FP32 type.");
NVTE_CHECK(scale.shape == std::vector<size_t>{ 1 }, "Scale must have 1 element.");
NVTE_CHECK(amax->dptr != nullptr, "AMAX is not allocated.");
NVTE_CHECK(amax->dtype == DType::kFloat32, "AMAX must have FP32 type.");
NVTE_CHECK(amax->shape == std::vector<size_t>{ 1 }, "AMAX must have 1 element.");
NVTE_CHECK(scale_inv->dptr != nullptr, "Inverted scale is not allocated.");
NVTE_CHECK(scale_inv->dtype == DType::kFloat32, "Inverted scale must have FP32 type.");
NVTE_CHECK(scale_inv->shape == std::vector<size_t>{ 1 }, "Inverted scale must have 1 element.");
const size_t N = product(input.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output->dtype, OType,
// nvec chosen so each thread loads 32 bytes of input per vector op.
constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryKernelLauncher<nvec, detail::Empty, detail::identity>(
reinterpret_cast<const IType*>(input.dptr),
reinterpret_cast<OType*>(output->dptr),
reinterpret_cast<const fp32*>(scale.dptr),
reinterpret_cast<fp32*>(scale_inv->dptr),
reinterpret_cast<fp32*>(amax->dptr),
N,
{},
stream);
); // NOLINT(*)
// --- post-refactor variant begins here (see NOTE above) ---
CheckInputTensor(input, "cast_input");
CheckOutputTensor(*output, "cast_output");
NVTE_CHECK(!is_fp8_dtype(input.data.dtype),
"Input must be in higher precision.");
NVTE_CHECK(is_fp8_dtype(output->data.dtype),
"Output must have FP8 type.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
const size_t N = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryKernelLauncher<nvec, detail::Empty, detail::identity>(
reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr),
reinterpret_cast<const fp32*>(output->scale.dptr),
reinterpret_cast<fp32*>(output->scale_inv.dptr),
reinterpret_cast<fp32*>(output->amax.dptr),
N,
{},
stream);
); // NOLINT(*)
); // NOLINT(*)
}
// Dequantizes an FP8 tensor back to higher precision by multiplying each
// element by scale_inv (via detail::dequantize_func); no amax is recorded.
// NOTE(review): this span is diff residue — the pre-refactor body (explicit
// scale_inv Tensor argument with manual NVTE_CHECKs) and the post-refactor
// body (scale_inv read from input.scale_inv, validated via
// CheckInput/OutputTensor) are BOTH present back to back, so `N` is declared
// twice. Only one variant should remain; confirm against upstream.
void fp8_dequantize(const Tensor &input,
const Tensor &scale_inv,
Tensor *output,
cudaStream_t stream) {
NVTE_CHECK(input.dtype == DType::kFloat8E4M3 ||
input.dtype == DType::kFloat8E5M2,
"Input must have FP8 type.");
NVTE_CHECK(input.dptr != nullptr, "Input is not allocated.");
NVTE_CHECK(output->dptr != nullptr, "Output is not allocated.");
NVTE_CHECK(output->dtype != DType::kFloat8E4M3 &&
output->dtype != DType::kFloat8E5M2,
"Output must be in higher precision.");
NVTE_CHECK(output->shape == input.shape, "Input and output shapes need to match.");
NVTE_CHECK(scale_inv.dptr != nullptr, "Inverted scale is not allocated.");
NVTE_CHECK(scale_inv.dtype == DType::kFloat32, "Inverted scale must have FP32 type.");
NVTE_CHECK(scale_inv.shape == std::vector<size_t>{ 1 }, "Inverted scale must have 1 element.");
const size_t N = product(input.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(input.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(output->dtype, OType,
// nvec chosen so each thread stores 32 bytes of output per vector op.
constexpr int nvec = 32 / sizeof(OType);
detail::DequantizeParam p;
p.scale_inv = reinterpret_cast<const fp32*>(scale_inv.dptr);
// scale/scale_inv/amax launcher slots are unused for dequantize,
// hence the nullptr arguments below.
VectorizedUnaryKernelLauncher<nvec, detail::DequantizeParam, detail::dequantize_func>(
reinterpret_cast<const IType*>(input.dptr),
reinterpret_cast<OType*>(output->dptr),
nullptr,
nullptr,
nullptr,
N,
p,
stream);
); // NOLINT(*)
// --- post-refactor variant begins here (see NOTE above) ---
CheckInputTensor(input, "cast_input");
CheckOutputTensor(*output, "cast_output");
NVTE_CHECK(is_fp8_dtype(input.data.dtype),
"Input must have FP8 type.");
NVTE_CHECK(!is_fp8_dtype(output->data.dtype),
"Output must be in higher precision.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
const size_t N = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(OType);
detail::DequantizeParam p;
p.scale_inv = reinterpret_cast<const fp32*>(input.scale_inv.dptr);
VectorizedUnaryKernelLauncher<nvec, detail::DequantizeParam, detail::dequantize_func>(
reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr),
nullptr,
nullptr,
nullptr,
N,
p,
stream);
); // NOLINT(*)
); // NOLINT(*)
}
} // namespace transformer_engine
/*! \brief C-API entry point for FP8 quantization.
 *
 *  Reinterprets the opaque handles as internal Tensor objects and forwards
 *  to fp8_quantize(); all validation happens in the callee.
 */
void nvte_fp8_quantize(const NVTETensor input,
                       const NVTETensor scale,
                       NVTETensor output,
                       NVTETensor amax,
                       NVTETensor scale_inv,
                       cudaStream_t stream) {
  using namespace transformer_engine;
  const auto *input_tensor = reinterpret_cast<const Tensor *>(input);
  const auto *scale_tensor = reinterpret_cast<const Tensor *>(scale);
  auto *output_tensor = reinterpret_cast<Tensor *>(output);
  auto *amax_tensor = reinterpret_cast<Tensor *>(amax);
  auto *scale_inv_tensor = reinterpret_cast<Tensor *>(scale_inv);
  fp8_quantize(*input_tensor, *scale_tensor, output_tensor,
               amax_tensor, scale_inv_tensor, stream);
}
/*! \brief C-API entry point for FP8 dequantization.
 *
 *  Handle-conversion shim around fp8_dequantize(); all validation happens
 *  in the callee.
 */
void nvte_fp8_dequantize(const NVTETensor input,
                         const NVTETensor scale_inv,
                         NVTETensor output,
                         cudaStream_t stream) {
  using namespace transformer_engine;
  const auto &input_tensor = *reinterpret_cast<const Tensor *>(input);
  const auto &scale_inv_tensor = *reinterpret_cast<const Tensor *>(scale_inv);
  fp8_dequantize(input_tensor, scale_inv_tensor,
                 reinterpret_cast<Tensor *>(output), stream);
}
......@@ -44,6 +44,41 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor)
}
// Builds a TensorWrapper from a raw data pointer plus raw pointers to the
// FP8 metadata (amax / scale / scale_inv), which the wrapper consumes as
// float* — the void* arguments are cast here for the caller's convenience.
transformer_engine::TensorWrapper makeTransformerEngineTensor(
        void* data_ptr,
        const std::vector<size_t>& shape,
        const transformer_engine::DType type,
        void* amax_ptr,
        void* scale_ptr,
        void* scale_inv_ptr) {
  auto* amax = reinterpret_cast<float*>(amax_ptr);
  auto* scale = reinterpret_cast<float*>(scale_ptr);
  auto* scale_inv = reinterpret_cast<float*>(scale_inv_ptr);
  return transformer_engine::TensorWrapper(data_ptr, shape, type,
                                           amax, scale, scale_inv);
}
/*! \brief Wrap a torch tensor together with its FP8 scaling metadata
 *         (amax / scale / scale_inv) into a TE tensor.
 *
 *  The metadata tensors must hold fp32 data; this is validated below with
 *  explicit diagnostic messages (matching the NVTE_CHECK style used elsewhere
 *  in the codebase).
 */
transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor,
                                                              at::Tensor amax,
                                                              const at::Tensor scale,
                                                              at::Tensor scale_inv) {
  transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type());
  std::vector<size_t> shape;
  shape.reserve(tensor.sizes().size());
  for (auto s : tensor.sizes()) {
    shape.push_back(s);
  }
  NVTE_CHECK(amax.scalar_type() == at::kFloat, "amax must be a float32 tensor.");
  NVTE_CHECK(scale.scalar_type() == at::kFloat, "scale must be a float32 tensor.");
  NVTE_CHECK(scale_inv.scalar_type() == at::kFloat, "scale_inv must be a float32 tensor.");

  return makeTransformerEngineTensor(tensor.data_ptr(), shape, dtype,
                                     amax.data_ptr(),
                                     scale.data_ptr(),
                                     scale_inv.data_ptr());
}
size_t product(const std::vector<size_t> &shape) {
size_t ret = 1;
for (auto s : shape) {
......@@ -124,21 +159,16 @@ void dispatch_layernorm(void* input, // i
auto input_cu = makeTransformerEngineTensor(input, input_shape, input_type);
auto gamma_cu = makeTransformerEngineTensor(gamma, gamma_shape, gamma_type);
auto beta_cu = makeTransformerEngineTensor(beta, beta_shape, beta_type);
auto scale_cu = makeTransformerEngineTensor(scale, scale_shape, scale_type);
auto z_cu = makeTransformerEngineTensor(z, z_shape, z_type);
auto z_cu = makeTransformerEngineTensor(z, z_shape, z_type, amax, scale, scale_inv);
auto mu_cu = makeTransformerEngineTensor(mu, mu_shape, mu_type);
auto rsigma_cu = makeTransformerEngineTensor(rsigma, rsigma_shape, rsigma_type);
auto amax_cu = makeTransformerEngineTensor(amax, amax_shape, amax_type);
auto scale_inv_cu = makeTransformerEngineTensor(scale_inv, scale_inv_shape, scale_inv_type);
transformer_engine::TensorWrapper workspace, barrier;
// This call populates workspace and barrier tensors with the required config
nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(),
scale_cu.data(), epsilon,
z_cu.data(), mu_cu.data(), rsigma_cu.data(),
epsilon, z_cu.data(), mu_cu.data(), rsigma_cu.data(),
at::cuda::getCurrentCUDAStream(), multiProcessorCount,
workspace.data(), barrier.data(), amax_cu.data(),
scale_inv_cu.data());
workspace.data(), barrier.data());
// Fill workspace and barrier
auto workspace_data = allocateSpace(workspace.shape(),
......@@ -155,11 +185,9 @@ void dispatch_layernorm(void* input, // i
// Actual call to fwd kernel
nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(),
scale_cu.data(), epsilon,
z_cu.data(), mu_cu.data(), rsigma_cu.data(),
epsilon, z_cu.data(), mu_cu.data(), rsigma_cu.data(),
at::cuda::getCurrentCUDAStream(), multiProcessorCount,
workspace.data(), barrier.data(), amax_cu.data(),
scale_inv_cu.data());
workspace.data(), barrier.data());
}
......@@ -184,17 +212,13 @@ void dispatch_cast_transpose_fusion(void* input,
) {
auto input_cu = makeTransformerEngineTensor(input, input_shape, input_type);
auto output_cast_cu = makeTransformerEngineTensor(output_cast, output_cast_shape,
output_cast_type);
output_cast_type, amax, scale,
scale_inv);
auto output_transpose_cu = makeTransformerEngineTensor(output_transpose, output_transpose_shape,
output_transpose_type);
auto scale_cu = makeTransformerEngineTensor(scale, scale_shape, scale_type);
auto amax_cu = makeTransformerEngineTensor(amax, amax_shape, amax_type);
auto scale_inv_cu = makeTransformerEngineTensor(scale_inv, scale_inv_shape,
scale_inv_type);
nvte_cast_transpose(input_cu.data(), scale_cu.data(),
output_cast_cu.data(), output_transpose_cu.data(),
amax_cu.data(), scale_inv_cu.data(),
output_transpose_type, amax,
scale, scale_inv);
nvte_cast_transpose(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(),
at::cuda::getCurrentCUDAStream());
}
......@@ -216,13 +240,10 @@ void dispatch_gelu(void* input, // i
const transformer_engine::DType scale_inv_type
) {
auto input_cu = makeTransformerEngineTensor(input, input_shape, input_type);
auto output_cu = makeTransformerEngineTensor(output, output_shape, output_type);
auto scale_cu = makeTransformerEngineTensor(scale, scale_shape, scale_type);
auto amax_cu = makeTransformerEngineTensor(amax, amax_shape, amax_type);
auto scale_inv_cu = makeTransformerEngineTensor(scale_inv, scale_inv_shape, scale_inv_type);
auto output_cu = makeTransformerEngineTensor(output, output_shape, output_type,
amax, scale, scale_inv);
nvte_gelu(input_cu.data(), output_cu.data(), scale_cu.data(),
amax_cu.data(), scale_inv_cu.data(), at::cuda::getCurrentCUDAStream());
nvte_gelu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
}
......@@ -263,22 +284,18 @@ void dispatch_bgrad_cast_transpose_fusion(void* input,
const transformer_engine::DType scale_inv_type
) {
auto input_cu = makeTransformerEngineTensor(input, input_shape, input_type);
auto scale_cu = makeTransformerEngineTensor(scale, scale_shape, scale_type);
auto cast_output_cu = makeTransformerEngineTensor(cast_output, cast_output_shape,
cast_output_type);
cast_output_type, amax, scale,
scale_inv);
auto transposed_output_cu = makeTransformerEngineTensor(transposed_output,
transposed_output_shape,
transposed_output_type);
auto amax_cu = makeTransformerEngineTensor(amax, amax_shape, amax_type);
transposed_output_type,
amax, scale, scale_inv);
auto dbias_cu = makeTransformerEngineTensor(dbias, dbias_shape, dbias_type);
auto scale_inv_cu = makeTransformerEngineTensor(scale_inv,
scale_inv_shape,
scale_inv_type);
transformer_engine::TensorWrapper workspace;
nvte_cast_transpose_dbias(input_cu.data(), scale_cu.data(), cast_output_cu.data(),
transposed_output_cu.data(), amax_cu.data(),
dbias_cu.data(), scale_inv_cu.data(),
nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(),
transposed_output_cu.data(), dbias_cu.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
// Fill workspace
......@@ -287,10 +304,9 @@ void dispatch_bgrad_cast_transpose_fusion(void* input,
workspace.shape(),
workspace.dtype());
nvte_cast_transpose_dbias(input_cu.data(), scale_cu.data(), cast_output_cu.data(),
transposed_output_cu.data(), amax_cu.data(),
dbias_cu.data(), scale_inv_cu.data(), workspace.data(),
at::cuda::getCurrentCUDAStream());
nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(),
transposed_output_cu.data(), dbias_cu.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
}
......@@ -324,22 +340,19 @@ void dispatch_bgrad_dgelu_cast_transpose_fusion(
auto gelu_input_cu = makeTransformerEngineTensor(gelu_input, gelu_input_shape,
gelu_input_type);
auto input_cu = makeTransformerEngineTensor(input, input_shape, input_type);
auto scale_cu = makeTransformerEngineTensor(scale, scale_shape, scale_type);
auto cast_output_cu = makeTransformerEngineTensor(cast_output, cast_output_shape,
cast_output_type);
cast_output_type, amax, scale,
scale_inv);
auto transposed_output_cu = makeTransformerEngineTensor(transposed_output,
transposed_output_shape,
transposed_output_type);
auto amax_cu = makeTransformerEngineTensor(amax, amax_shape, amax_type);
transposed_output_type,
amax, scale, scale_inv);
auto dbias_cu = makeTransformerEngineTensor(dbias, dbias_shape, dbias_type);
auto scale_inv_cu = makeTransformerEngineTensor(scale_inv,
scale_inv_shape,
scale_inv_type);
nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), scale_cu.data(),
nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(),
cast_output_cu.data(), transposed_output_cu.data(),
amax_cu.data(), dbias_cu.data(), scale_inv_cu.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
dbias_cu.data(), workspace.data(),
at::cuda::getCurrentCUDAStream());
// Fill workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
......@@ -347,10 +360,10 @@ void dispatch_bgrad_dgelu_cast_transpose_fusion(
workspace.shape(),
workspace.dtype());
nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), scale_cu.data(),
nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(),
cast_output_cu.data(), transposed_output_cu.data(),
amax_cu.data(), dbias_cu.data(), scale_inv_cu.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
dbias_cu.data(), workspace.data(),
at::cuda::getCurrentCUDAStream());
}
......@@ -377,56 +390,51 @@ void dispatch_multi_cast_transpose(
transformer_engine::TensorWrapper workspace;
// Construct TE tensors
std::vector<NVTETensor> input_list, scale_list,
cast_output_list, transposed_output_list, amax_list, scale_inv_list;
std::vector<NVTETensor> input_list,
cast_output_list, transposed_output_list;
std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
auto make_tensor = [&tensor_wrappers](void* dptr,
const std::vector<size_t>& shape,
transformer_engine::DType dtype)
transformer_engine::DType dtype,
void* amax_dptr,
void* scale_dptr,
void* scale_inv_dptr)
-> NVTETensor {
tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype));
tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr,
scale_dptr, scale_inv_dptr));
return tensor_wrappers.back().data();
};
for (size_t i = 0; i < input_dptr_list.size(); ++i) {
input_list.emplace_back(make_tensor(input_dptr_list[i],
input_shape_list[i],
input_type_list[i]));
scale_list.emplace_back(make_tensor(scale_dptr_list[i],
scale_shape_list[i],
scale_type_list[i]));
input_type_list[i],
nullptr,
nullptr,
nullptr));
cast_output_list.emplace_back(make_tensor(cast_output_dptr_list[i],
cast_output_shape_list[i],
cast_output_type_list[i]));
cast_output_type_list[i],
amax_dptr_list[i],
scale_dptr_list[i],
scale_inv_dptr_list[i]));
transposed_output_list.emplace_back(make_tensor(transposed_output_dptr_list[i],
transposed_output_shape_list[i],
transposed_output_type_list[i]));
amax_list.emplace_back(make_tensor(amax_dptr_list[i],
amax_shape_list[i],
amax_type_list[i]));
scale_inv_list.emplace_back(make_tensor(scale_inv_dptr_list[i],
scale_inv_shape_list[i],
scale_inv_type_list[i]));
transposed_output_type_list[i],
amax_dptr_list[i],
scale_dptr_list[i],
scale_inv_dptr_list[i]));
}
// Check tensor lists
NVTE_CHECK(scale_list.size() == input_list.size(),
"Number of input and scale tensors must match");
NVTE_CHECK(cast_output_list.size() == input_list.size(),
"Number of input and C output tensors must match");
NVTE_CHECK(transposed_output_list.size() == input_list.size(),
"Number of input and T output tensors must match");
NVTE_CHECK(amax_list.size() == input_list.size(),
"Number of input and AMAX tensors must match");
NVTE_CHECK(scale_inv_list.size() == input_list.size(),
"Number of input and scale_inv tensors must match");
// Launch TE kernel
nvte_multi_cast_transpose(input_list.size(),
input_list.data(),
scale_list.data(),
cast_output_list.data(),
transposed_output_list.data(),
amax_list.data(),
scale_inv_list.data(),
at::cuda::getCurrentCUDAStream());
}
......@@ -109,6 +109,14 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
const transformer_engine::DType type
);
transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
const std::vector<size_t>& shape,
const transformer_engine::DType type,
void* amax_ptr,
void* scale_ptr,
void* scale_inv_ptr
);
transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
const NVTEShape& shape,
......@@ -118,6 +126,11 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor);
transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor,
at::Tensor amax,
const at::Tensor scale,
at::Tensor scale_inv);
size_t product(const std::vector<size_t> &shape);
......
......@@ -29,17 +29,13 @@ void te_gemm(at::Tensor A,
auto te_A = makeTransformerEngineTensor(A.data_ptr(),
{static_cast<size_t>(A.size(0)),
static_cast<size_t>(A.size(1))},
A_type);
auto te_A_scale_inverse = makeTransformerEngineTensor(A_scale_inverse.data_ptr(), {1},
GetTransformerEngineDType(
A_scale_inverse.scalar_type()));
A_type, nullptr, nullptr,
A_scale_inverse.data_ptr());
auto te_B = makeTransformerEngineTensor(B.data_ptr(),
{static_cast<size_t>(B.size(0)),
static_cast<size_t>(B.size(1))},
B_type);
auto te_B_scale_inverse = makeTransformerEngineTensor(B_scale_inverse.data_ptr(), {1},
GetTransformerEngineDType(
B_scale_inverse.scalar_type()));
B_type, nullptr, nullptr,
B_scale_inverse.data_ptr());
auto te_D = makeTransformerEngineTensor(D.data_ptr(),
{static_cast<size_t>(D.size(0)),
static_cast<size_t>(D.size(1))},
......@@ -60,9 +56,7 @@ void te_gemm(at::Tensor A,
DType::kByte);
nvte_cublas_gemm(te_A.data(),
te_A_scale_inverse.data(),
te_B.data(),
te_B_scale_inverse.data(),
te_D.data(),
te_bias.data(),
te_pre_gelu_out.data(),
......@@ -448,13 +442,11 @@ at::Tensor cast_to_fp8(const at::Tensor &input,
auto output = at::empty_like(input, at::CUDA(GetATenDType(otype)));
auto input_cu = makeTransformerEngineTensor(input);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, H}, otype);
auto scale_cu = makeTransformerEngineTensor(scale.data_ptr(), {1}, DType::kFloat32);
auto amax_cu = makeTransformerEngineTensor(amax.data_ptr(), {1}, DType::kFloat32);
auto scale_inv_cu = makeTransformerEngineTensor(scale_inv.data_ptr(), {1}, DType::kFloat32);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, H}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
nvte_fp8_quantize(input_cu.data(), scale_cu.data(), output_cu.data(),
amax_cu.data(), scale_inv_cu.data(),
nvte_fp8_quantize(input_cu.data(), output_cu.data(),
at::cuda::getCurrentCUDAStream());
return output;
......@@ -472,11 +464,12 @@ at::Tensor cast_from_fp8(const at::Tensor &input,
auto output = at::empty_like(input, at::CUDA(GetATenDType(otype)));
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {N, H}, itype);
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {N, H}, itype,
nullptr, nullptr, scale_inv.data_ptr());
auto output_cu = makeTransformerEngineTensor(output);
auto scale_inv_cu = makeTransformerEngineTensor(scale_inv.data_ptr(), {1}, DType::kFloat32);
nvte_fp8_dequantize(input_cu.data(), scale_inv_cu.data(), output_cu.data(),
nvte_fp8_dequantize(input_cu.data(), output_cu.data(),
at::cuda::getCurrentCUDAStream());
return output;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment