Unverified commit a5ba71f3, authored by Przemyslaw Tredak, committed by GitHub
Browse files

Move the amax/scale/scale_inv into the TE Tensor struct. (#33)



* Move the amax/scale/scale_inv into the TE Tensor struct.
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Handle multi_cast_transpose
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Changed softmax to new Tensor
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* First pass at the cpp tests
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Round of fixes
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix multi_cast_transpose
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix cast_to_fp8
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemyslaw Tredak <ptrendx@gmail.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 509bf877
......@@ -47,29 +47,28 @@ void performTest(const size_t N, const size_t H) {
Tensor input({ N, H }, itype);
Tensor output_c({ N, H }, otype);
Tensor output_t({ H, N }, otype);
Tensor scale({ 1 }, DType::kFloat32);
Tensor amax({ 1 }, DType::kFloat32);
Tensor scale_inv({ 1 }, DType::kFloat32);
std::unique_ptr<OutputType[]> ref_output_c = std::make_unique<OutputType[]>(N * H);
std::unique_ptr<OutputType[]> ref_output_t = std::make_unique<OutputType[]>(N * H);
fillUniform(input);
fillUniform(scale);
fillUniform(&input);
setRandomScale(&output_c);
output_t.shareFP8Meta(output_c);
nvte_cast_transpose(input.data(), scale.data(), output_c.data(), output_t.data(),
amax.data(), scale_inv.data(), 0);
nvte_cast_transpose(input.data(), output_c.data(), output_t.data(), 0);
float ref_amax;
compute_ref<InputType, OutputType>(input.cpu_dptr<InputType>(), ref_output_c.get(),
ref_output_t.get(), N, H, &ref_amax,
*(scale.cpu_dptr<float>()));
output_c.scale());
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", amax, &ref_amax, atol_amax, rtol_amax);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), atol, rtol);
compareResults("output_t", output_t, ref_output_t.get(), atol, rtol);
......
......@@ -23,7 +23,7 @@ namespace {
template <typename IT, typename OT, typename CT>
void compute_ref_cast_transpose_dbias(const IT *input_h,
const CT *scale_h,
const CT scale,
OT *output_c_h,
OT *output_t_h,
CT *amax_h,
......@@ -31,7 +31,6 @@ void compute_ref_cast_transpose_dbias(const IT *input_h,
const size_t N,
const size_t H) {
CT amax = 0.;
CT scale = *scale_h;
std::vector<CT> acc_dbias(H, 0.);
......@@ -67,17 +66,15 @@ void performTest(const size_t N, const size_t H) {
DType ctype = TypeInfo<CType>::dtype;
Tensor input({N, H}, itype);
Tensor scale({1}, ctype);
Tensor output_c({N, H}, otype);
Tensor output_t({ H, N}, otype);
Tensor amax({1}, ctype);
Tensor scale_inv({1}, ctype);
// dbias has the same data type with "output grad"
Tensor dbias({H}, itype);
fillUniform(input);
fillUniform(scale);
fillUniform(&input);
setRandomScale(&output_c);
output_t.shareFP8Meta(output_c);
std::unique_ptr<OType[]> ref_output_c = std::make_unique<OType[]>(N*H);
std::unique_ptr<OType[]> ref_output_t = std::make_unique<OType[]>(N*H);
......@@ -85,7 +82,7 @@ void performTest(const size_t N, const size_t H) {
CType ref_amax;
compute_ref_cast_transpose_dbias(input.cpu_dptr<IType>(),
scale.cpu_dptr<CType>(),
output_c.scale(),
ref_output_c.get(),
ref_output_t.get(),
&ref_amax,
......@@ -95,12 +92,9 @@ void performTest(const size_t N, const size_t H) {
Tensor workspace;
nvte_cast_transpose_dbias(input.data(),
scale.data(),
output_c.data(),
output_t.data(),
amax.data(),
dbias.data(),
scale_inv.data(),
workspace.data(),
0);
......@@ -108,12 +102,9 @@ void performTest(const size_t N, const size_t H) {
nvte_cast_transpose_dbias(input.data(),
scale.data(),
output_c.data(),
output_t.data(),
amax.data(),
dbias.data(),
scale_inv.data(),
workspace.data(),
0);
......@@ -121,11 +112,12 @@ void performTest(const size_t N, const size_t H) {
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", amax, &ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / (*scale.cpu_dptr<float>());
compareResults("scale_inv", scale_inv, &ref_scale_inv, atol_amax, rtol_amax);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), atol, rtol);
compareResults("output_t", output_t, ref_output_t.get(), atol, rtol);
......
......@@ -32,7 +32,7 @@ CType dgelu(const CType cval) {
template <typename IT, typename OT, typename CT>
void compute_ref_cast_transpose_dbias_dgelu(const IT *input,
const IT *gelu_input,
const CT *scale_h,
const CT scale,
OT *output_c,
OT *output_t,
CT *amax_h,
......@@ -40,7 +40,6 @@ void compute_ref_cast_transpose_dbias_dgelu(const IT *input,
const size_t N,
const size_t H) {
CT amax = 0.;
CT scale = *scale_h;
std::vector<CT> acc_dbias(H, 0.);
......@@ -79,18 +78,16 @@ void performTest(const size_t N, const size_t H) {
Tensor input({N, H}, itype);
Tensor gelu_input({N, H}, itype);
Tensor scale({1}, ctype);
Tensor output_c({N, H}, otype);
Tensor output_t({ H, N}, otype);
Tensor amax({1}, ctype);
Tensor scale_inv({1}, ctype);
// dbias has the same data type with "output grad"
Tensor dbias({H}, itype);
fillUniform(input);
fillUniform(gelu_input);
fillUniform(scale);
fillUniform(&input);
fillUniform(&gelu_input);
setRandomScale(&output_c);
output_t.shareFP8Meta(output_c);
std::unique_ptr<OType[]> ref_output_c = std::make_unique<OType[]>(N*H);
std::unique_ptr<OType[]> ref_output_t = std::make_unique<OType[]>(N*H);
......@@ -99,7 +96,7 @@ void performTest(const size_t N, const size_t H) {
CType ref_amax;
compute_ref_cast_transpose_dbias_dgelu(input.cpu_dptr<IType>(),
gelu_input.cpu_dptr<IType>(),
scale.cpu_dptr<CType>(),
output_c.scale(),
ref_output_c.get(),
ref_output_t.get(),
&ref_amax,
......@@ -110,12 +107,9 @@ void performTest(const size_t N, const size_t H) {
nvte_cast_transpose_dbias_dgelu(input.data(),
gelu_input.data(),
scale.data(),
output_c.data(),
output_t.data(),
amax.data(),
dbias.data(),
scale_inv.data(),
workspace.data(),
0);
......@@ -124,12 +118,9 @@ void performTest(const size_t N, const size_t H) {
nvte_cast_transpose_dbias_dgelu(input.data(),
gelu_input.data(),
scale.data(),
output_c.data(),
output_t.data(),
amax.data(),
dbias.data(),
scale_inv.data(),
workspace.data(),
0);
......@@ -137,10 +128,12 @@ void performTest(const size_t N, const size_t H) {
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", amax, &ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / (*scale.cpu_dptr<float>());
compareResults("scale_inv", scale_inv, &ref_scale_inv, atol_amax, rtol_amax);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), atol, rtol);
......
......@@ -23,16 +23,11 @@ using namespace transformer_engine;
template <typename IT, typename OT, typename CT>
void compute_ref_gelu_cast(const IT *input_h,
OT *output_h,
const CT *scale_h,
const CT scale,
CT *amax_h,
const size_t N,
const size_t H) {
CT amax = 0.;
CT scale = 1;
if (std::is_same<OT, test::fp8e4m3>::value ||
std::is_same<OT, test::fp8e5m2>::value) {
scale = *scale_h;
}
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < H; j++) {
......@@ -51,30 +46,22 @@ template <typename IType, typename OType>
void performTestGelu(const size_t N, const size_t H) {
using namespace test;
using CType = fp32;
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;
DType ctype = TypeInfo<CType>::dtype;
Tensor input({ N, H }, itype);
Tensor output({ N, H }, otype);
Tensor scale({ 1 }, ctype);
Tensor amax({ 1 }, ctype);
Tensor scale_inv({ 1 }, ctype);
fillUniform(input);
fillUniform(scale);
fillUniform(&input);
setRandomScale(&output);
std::unique_ptr<OType[]> ref_output = std::make_unique<OType[]>(N*H);
nvte_gelu(input.data(), output.data(), scale.data(),
amax.data(), scale_inv.data(), 0);
nvte_gelu(input.data(), output.data(), 0);
float ref_amax;
compute_ref_gelu_cast(input.cpu_dptr<IType>(), ref_output.get(),
scale.cpu_dptr<float>(),
&ref_amax, N, H);
output.scale(), &ref_amax, N, H);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
......@@ -82,9 +69,9 @@ void performTestGelu(const size_t N, const size_t H) {
if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", amax, &ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / (*scale.cpu_dptr<float>());
compareResults("scale_inv", scale_inv, &ref_scale_inv, atol_amax, rtol_amax);
compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output.scale();
compareResults("scale_inv", output.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_gelu", output, ref_output.get(), atol, rtol);
......
......@@ -132,9 +132,6 @@ void performTest(const size_t N, const size_t H) {
Tensor z({ N, H }, otype);
Tensor gamma({ H }, wtype);
Tensor beta({ H }, wtype);
Tensor scale({ 1 }, DType::kFloat32);
Tensor amax({ 1 }, DType::kFloat32);
Tensor scale_inv({ 1 }, DType::kFloat32);
Tensor mu({ N }, DType::kFloat32);
Tensor rsigma({ N }, DType::kFloat32);
Tensor dz({ N, H }, wtype);
......@@ -143,11 +140,11 @@ void performTest(const size_t N, const size_t H) {
Tensor dbeta({ H }, wtype);
Tensor workspace, barrier, dgamma_part, dbeta_part;
fillUniform(input);
fillUniform(gamma);
fillUniform(beta);
fillUniform(scale);
fillUniform(dz);
fillUniform(&input);
fillUniform(&gamma);
fillUniform(&beta);
setRandomScale(&z);
fillUniform(&dz);
std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N * H);
std::unique_ptr<float[]> ref_mu = std::make_unique<float[]>(N);
......@@ -161,14 +158,14 @@ void performTest(const size_t N, const size_t H) {
// Forward kernel
float epsilon = 1e-5;
nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), scale.data(), epsilon,
nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon,
z.data(), mu.data(), rsigma.data(), 0, prop.multiProcessorCount,
workspace.data(), barrier.data(), amax.data(), scale_inv.data());
workspace.data(), barrier.data());
workspace = Tensor(workspace.shape(), workspace.dtype());
barrier = Tensor(barrier.shape(), barrier.dtype());
nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), scale.data(), epsilon,
nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon,
z.data(), mu.data(), rsigma.data(), 0, prop.multiProcessorCount,
workspace.data(), barrier.data(), amax.data(), scale_inv.data());
workspace.data(), barrier.data());
// Backward kernel
nvte_layernorm_bwd(dz.data(), input.data(),
......@@ -195,7 +192,7 @@ void performTest(const size_t N, const size_t H) {
float ref_amax;
compute_ref_stats(input.cpu_dptr<InputType>(), ref_mu.get(),
ref_rsigma.get(), N, H, epsilon);
float ref_scale = isFp8Type(otype) ? *(scale.cpu_dptr<float>()) : 1.f;
float ref_scale = isFp8Type(otype) ? z.scale() : 1.f;
compute_ref_output(input.cpu_dptr<InputType>(),
gamma.cpu_dptr<WeightType>(),
beta.cpu_dptr<WeightType>(),
......@@ -217,9 +214,9 @@ void performTest(const size_t N, const size_t H) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
if (isFp8Type(otype)) {
compareResults("amax", amax, &ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / (*scale.cpu_dptr<float>());
compareResults("scale_inv", scale_inv, &ref_scale_inv, atol_amax, rtol_amax);
compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / z.scale();
compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32);
......
......@@ -60,7 +60,6 @@ void performTest() {
const DType itype = TypeInfo<InputType>::dtype;
const DType otype = TypeInfo<OutputType>::dtype;
const DType ctype = DType::kFloat32;
const std::vector<std::pair<size_t, size_t>> tensor_dims = {{1,1},
{1,768},
{768,1},
......@@ -72,8 +71,7 @@ void performTest() {
const size_t num_tensors = tensor_dims.size();
// Buffers for Transformer Engine implementation
std::vector<Tensor> input_list, output_c_list, output_t_list,
scale_list, amax_list, scale_inv_list;
std::vector<Tensor> input_list, output_c_list, output_t_list;
// Buffers for reference implementation
std::vector<std::vector<InputType>> ref_input_list;
......@@ -89,16 +87,13 @@ void performTest() {
input_list.emplace_back(Tensor({ height, width }, itype));
output_c_list.emplace_back(Tensor({ height, width }, otype));
output_t_list.emplace_back(Tensor({ width, height }, otype));
scale_list.emplace_back(Tensor({ 1 }, ctype));
amax_list.emplace_back(Tensor({ 1 }, ctype));
scale_inv_list.emplace_back(Tensor({ 1 }, ctype));
auto& input = input_list.back();
auto& scale = scale_list.back();
fillUniform(input);
fillUniform(scale);
*scale.cpu_dptr<float>() += 2.5;
scale.from_cpu();
auto& output_c = output_c_list.back();
auto& output_t = output_t_list.back();
fillUniform(&input);
setRandomScale(&output_c);
output_t.shareFP8Meta(output_c);
ref_input_list.emplace_back(height*width);
ref_output_c_list.emplace_back(height*width);
......@@ -107,7 +102,7 @@ void performTest() {
std::copy(input.cpu_dptr<InputType>(),
input.cpu_dptr<InputType>() + height * width,
ref_input_list.back().begin());
ref_scale_list[tensor_id] = *scale.cpu_dptr<float>();
ref_scale_list[tensor_id] = output_c.scale();
ref_height_list[tensor_id] = height;
ref_width_list[tensor_id] = width;
}
......@@ -123,11 +118,8 @@ void performTest() {
};
nvte_multi_cast_transpose(num_tensors,
make_nvte_vector(input_list).data(),
make_nvte_vector(scale_list).data(),
make_nvte_vector(output_c_list).data(),
make_nvte_vector(output_t_list).data(),
make_nvte_vector(amax_list).data(),
make_nvte_vector(scale_inv_list).data(),
0);
// Reference implementation
......@@ -145,15 +137,17 @@ void performTest() {
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax",
amax_list[tensor_id],
&ref_amax_list[tensor_id],
output_c_list[tensor_id].amax(),
ref_amax_list[tensor_id],
atol_amax, rtol_amax);
compareResults("scale_inv",
scale_inv_list[tensor_id],
&ref_scale_inv_list[tensor_id],
output_c_list[tensor_id].scale_inv(),
ref_scale_inv_list[tensor_id],
atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c",
output_c_list[tensor_id],
......
......@@ -60,29 +60,26 @@ void performTestQ(const size_t N) {
Tensor input({ N }, itype);
Tensor output({ N }, otype);
Tensor scale({ 1 }, DType::kFloat32);
Tensor amax({ 1 }, DType::kFloat32);
Tensor scale_inv({ 1 }, DType::kFloat32);
std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N);
fillUniform(input);
fillUniform(scale);
fillUniform(&input);
setRandomScale(&output);
nvte_fp8_quantize(input.data(), scale.data(), output.data(), amax.data(), scale_inv.data(), 0);
nvte_fp8_quantize(input.data(), output.data(), 0);
float ref_amax;
compute_ref_q<InputType, OutputType>(input.cpu_dptr<InputType>(), ref_output.get(),
N, &ref_amax, *(scale.cpu_dptr<float>()));
N, &ref_amax, output.scale());
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", amax, &ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / (*scale.cpu_dptr<float>());
compareResults("scale_inv", scale_inv, &ref_scale_inv, atol_amax, rtol_amax);
compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output.scale();
compareResults("scale_inv", output.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
auto [atol, rtol] = getTolerances(otype);
compareResults("output_q", output, ref_output.get(), atol, rtol);
}
......@@ -96,17 +93,15 @@ void performTestDQ(const size_t N) {
Tensor input({ N }, itype);
Tensor output({ N }, otype);
Tensor scale_inv({ 1 }, DType::kFloat32);
std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N);
fillUniform(input);
fillUniform(scale_inv);
fillUniform(&input);
nvte_fp8_dequantize(input.data(), scale_inv.data(), output.data(), 0);
nvte_fp8_dequantize(input.data(), output.data(), 0);
compute_ref_dq<InputType, OutputType>(input.cpu_dptr<InputType>(), ref_output.get(),
N, *(scale_inv.cpu_dptr<float>()));
N, input.scale_inv());
cudaDeviceSynchronize();
auto err = cudaGetLastError();
......
......@@ -41,7 +41,7 @@ void performTest(const size_t N, const size_t H) {
std::unique_ptr<Type[]> ref_output = std::make_unique<Type[]>(N * H);
fillUniform(input);
fillUniform(&input);
nvte_transpose(input.data(), output.data(), 0);
......
......@@ -63,24 +63,87 @@ Tensor::Tensor(const NVTEShape &shape, const DType type) {
size_t total_size = product(shape) * s;
void *dptr = nullptr;
cpu_data_ = nullptr;
amax_cpu_data_ = nullptr;
scale_cpu_data_ = nullptr;
scale_inv_cpu_data_ = nullptr;
float *amax = nullptr, *scale = nullptr, *scale_inv = nullptr;
if (total_size != 0) {
cudaMalloc((void**)&dptr, total_size); // NOLINT(*)
cudaMemset(dptr, 0, total_size);
cpu_data_ = std::make_unique<unsigned char[]>(total_size);
for (size_t i = 0; i < total_size; ++i) {
cpu_data_[i] = 0;
}
tensor_ = TensorWrapper(dptr, shape, type);
}
if (isFp8Type(type)) {
cudaMalloc((void**)&amax, sizeof(float)); // NOLINT(*)
cudaMemset(amax, 0, sizeof(float));
cudaMalloc((void**)&scale, sizeof(float)); // NOLINT(*)
cudaMemset(scale, 0, sizeof(float));
cudaMalloc((void**)&scale_inv, sizeof(float)); // NOLINT(*)
cudaMemset(scale_inv, 0, sizeof(float));
amax_cpu_data_ = std::make_shared<float>();
*amax_cpu_data_ = 0;
scale_cpu_data_ = std::make_shared<float>();
*scale_cpu_data_ = 0;
scale_inv_cpu_data_ = std::make_shared<float>();
*scale_inv_cpu_data_ = 0;
}
tensor_ = TensorWrapper(dptr, shape, type, amax, scale, scale_inv);
}
void Tensor::to_cpu() const {
const NVTEShape s = tensor_.shape();
const size_t size = product(s) * typeToSize(tensor_.dtype());
cudaMemcpy(cpu_data_.get(), tensor_.dptr(), size, cudaMemcpyDeviceToHost);
if (isFp8Type(dtype())) {
cudaMemcpy(amax_cpu_data_.get(), tensor_.amax(), sizeof(float),
cudaMemcpyDeviceToHost);
cudaMemcpy(scale_cpu_data_.get(), tensor_.scale(), sizeof(float),
cudaMemcpyDeviceToHost);
cudaMemcpy(scale_inv_cpu_data_.get(), tensor_.scale_inv(), sizeof(float),
cudaMemcpyDeviceToHost);
}
}
void Tensor::from_cpu() const {
const NVTEShape s = tensor_.shape();
const size_t size = product(s) * typeToSize(tensor_.dtype());
cudaMemcpy(tensor_.dptr(), cpu_data_.get(), size, cudaMemcpyHostToDevice);
if (isFp8Type(dtype())) {
cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float),
cudaMemcpyHostToDevice);
cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float),
cudaMemcpyHostToDevice);
cudaMemcpy(tensor_.scale_inv(), scale_inv_cpu_data_.get(), sizeof(float),
cudaMemcpyHostToDevice);
}
}
// Install an FP8 scale on the host shadow and sync it to the device.
// Silently a no-op for non-FP8 tensors, which carry no scaling metadata.
void Tensor::set_scale(float scale) {
  if (!isFp8Type(dtype())) return;
  NVTE_CHECK(scale_cpu_data_);
  *scale_cpu_data_ = scale;
  from_cpu();
}
// Install an FP8 inverse scale on the host shadow and sync it to the device.
// Silently a no-op for non-FP8 tensors, which carry no scaling metadata.
void Tensor::set_scale_inv(float scale_inv) {
  if (!isFp8Type(dtype())) return;
  NVTE_CHECK(scale_inv_cpu_data_);
  *scale_inv_cpu_data_ = scale_inv;
  from_cpu();
}
// Rebind this tensor's FP8 metadata so it points at `other`'s device-side
// amax/scale/scale_inv buffers; both tensors then observe the same scaling
// state. Only meaningful when both tensors are FP8; otherwise does nothing.
void Tensor::shareFP8Meta(const Tensor &other) {
  if (!isFp8Type(dtype()) || !isFp8Type(other.dtype())) return;
  tensor_ = TensorWrapper(dptr(), shape(), dtype(),
                          other.tensor_.amax(),
                          other.tensor_.scale(),
                          other.tensor_.scale_inv());
  // Refresh the host shadows from the (now shared) device metadata.
  to_cpu();
}
using std::to_string;
......@@ -141,6 +204,16 @@ void compareResults(const std::string &name, const Tensor &test, const void *ref
);
}
void compareResults(const std::string &name, const float test, const float ref,
double atol, double rtol) {
double t = static_cast<double>(test);
double r = static_cast<double>(ref);
bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
ASSERT_FALSE(mismatch) << "Error in " << name << std::endl
<< "Mismatch: " << t << " vs " << r;
}
std::pair<double, double> getTolerances(const DType type) {
switch(type) {
case DType::kFloat32:
......@@ -158,17 +231,25 @@ std::pair<double, double> getTolerances(const DType type) {
return {0, 0};
}
void fillUniform(const Tensor &t) {
const size_t size = product(t.shape());
void fillUniform(Tensor *t) {
const size_t size = product(t->shape());
static std::mt19937 gen(12345);
std::uniform_real_distribution<> dis(-2.0, 1.0);
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t.dtype(), T, {
T *data = t.cpu_dptr<T>();
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, {
T *data = t->cpu_dptr<T>();
for (size_t i = 0; i < size; ++i) {
data[i] = T(dis(gen));
}
});
t.from_cpu();
t->set_scale_inv(dis(gen));
t->from_cpu();
}
// Draw a deterministic pseudo-random value in [-2, 1) and install it as the
// tensor's FP8 scale. The generator is static, so repeated calls advance the
// same sequence rather than restarting it.
void setRandomScale(Tensor *t) {
  static std::mt19937 rng(12345);
  std::uniform_real_distribution<> dist(-2.0, 1.0);
  t->set_scale(static_cast<float>(dist(rng)));
}
bool isFp8Type(DType type) {
......
......@@ -130,12 +130,45 @@ class Tensor {
return reinterpret_cast<T *>(cpu_data_.get());
}
// Latest amax value, synced from the device; 0 when the tensor has no
// FP8 metadata (non-FP8 tensors never allocate the host shadow).
float amax() const {
  if (!amax_cpu_data_) return 0;
  to_cpu();
  return *amax_cpu_data_;
}
// Current FP8 scale, synced from the device; identity scale (1) when the
// tensor has no FP8 metadata.
float scale() const {
  if (!scale_cpu_data_) return 1;
  to_cpu();
  return *scale_cpu_data_;
}
// Current FP8 inverse scale, synced from the device; identity (1) when the
// tensor has no FP8 metadata.
float scale_inv() const {
  if (!scale_inv_cpu_data_) return 1;
  to_cpu();
  return *scale_inv_cpu_data_;
}
void to_cpu() const;
void from_cpu() const;
void set_scale(float scale);
void set_scale_inv(float scale_inv);
void shareFP8Meta(const Tensor &other);
private:
TensorWrapper tensor_;
std::unique_ptr<unsigned char[]> cpu_data_;
std::shared_ptr<float> amax_cpu_data_;
std::shared_ptr<float> scale_cpu_data_;
std::shared_ptr<float> scale_inv_cpu_data_;
};
size_t typeToSize(DType type);
......@@ -145,10 +178,13 @@ bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2);
void compareResults(const std::string &name, const Tensor &test, const void *ref,
double atol = 1e-5, double rtol = 1e-8);
void compareResults(const std::string &name, const float test, const float ref,
double atol = 1e-5, double rtol = 1e-8);
std::pair<double, double> getTolerances(const DType type);
void fillUniform(const Tensor &t);
void fillUniform(Tensor *t);
void setRandomScale(Tensor *t);
constexpr int THREADS_PER_WARP = 32;
......
......@@ -26,39 +26,24 @@ __device__ inline fp32 gelu(fp32 value, const GELUParam &) {
}
void gelu_cast(const Tensor &input,
const Tensor &scale,
Tensor *output,
Tensor *amax,
Tensor *scale_inv,
cudaStream_t stream) {
NVTE_CHECK(input.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(output->shape.size() == 2, "Output must have 2 dimensions.");
NVTE_CHECK(input.shape == output->shape, "Input and output shapes must match.");
const size_t tot_elts = input.shape[1] * input.shape[0];
NVTE_CHECK(amax->shape == std::vector<size_t>{ 1 }, "AMAX tensor must have 1 element.");
NVTE_CHECK(amax->dtype == DType::kFloat32, "AMAX tensor must have Float32 type.");
NVTE_CHECK(scale.shape == std::vector<size_t>{ 1 }, "Scale tensor must have 1 element.");
NVTE_CHECK(scale.dtype == DType::kFloat32, "Scale tensor must have Float32 type.");
NVTE_CHECK(scale_inv->shape == std::vector<size_t>{ 1 },
"scale_inv tensor must have 1 element.");
NVTE_CHECK(scale_inv->dtype == DType::kFloat32, "scale_inv tensor must have Float32 type.");
NVTE_CHECK(input.dptr != nullptr, "Input is not allocated.");
NVTE_CHECK(scale.dptr != nullptr, "Scale is not allocated.");
NVTE_CHECK(output->dptr != nullptr, "Output is not allocated.");
NVTE_CHECK(amax->dptr != nullptr, "AMAX tensor is not allocated.");
NVTE_CHECK(scale_inv->dptr != nullptr, "scale_inv tensor is not allocated.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->dtype, OType,
CheckInputTensor(input, "gelu_input");
CheckOutputTensor(*output, "gelu_output");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
const size_t tot_elts = input.data.shape[1] * input.data.shape[0];
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryKernelLauncher<nvec, detail::GELUParam, detail::gelu>(
reinterpret_cast<const IType*>(input.dptr),
reinterpret_cast<OType*>(output->dptr),
reinterpret_cast<const fp32*>(scale.dptr),
reinterpret_cast<fp32*>(scale_inv->dptr),
reinterpret_cast<fp32*>(amax->dptr),
reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr),
reinterpret_cast<const fp32*>(output->scale.dptr),
reinterpret_cast<fp32*>(output->scale_inv.dptr),
reinterpret_cast<fp32*>(output->amax.dptr),
tot_elts,
{},
stream);
......@@ -70,15 +55,9 @@ void gelu_cast(const Tensor &input,
// C API entry point for GELU with optional cast to FP8. After this commit
// the FP8 scaling metadata (scale/amax/scale_inv) travels inside the output
// NVTETensor itself, so the extra scale/amax/scale_inv parameters are gone.
// NOTE: this span of the diff interleaved the removed five-argument form
// with the new three-argument form; the new form (matching the
// `nvte_gelu(input.data(), output.data(), 0)` call sites and gelu_cast's
// new parameter list) is kept here.
void nvte_gelu(const NVTETensor input,
               NVTETensor output,
               cudaStream_t stream) {
  using namespace transformer_engine;
  gelu_cast(*reinterpret_cast<const Tensor*>(input),
            reinterpret_cast<Tensor*>(output),
            stream);
}
......@@ -23,12 +23,26 @@
namespace transformer_engine {
struct Tensor {
void* dptr;
struct SimpleTensor {
void *dptr;
std::vector<size_t> shape;
DType dtype;
Tensor() : dptr(nullptr), shape(), dtype(DType::kFloat32) {}
SimpleTensor(void *dptr, const std::vector<size_t> &shape, DType dtype) :
dptr(dptr), shape(shape), dtype(dtype) {}
SimpleTensor() : SimpleTensor(nullptr, {}, DType::kFloat32) {}
};
// Internal TE tensor: the data payload plus the FP8 scaling metadata that
// used to be passed to kernels as separate tensors. amax/scale/scale_inv
// are single-float tensors (shape {1}) whose dptr is default-initialized
// to nullptr.
struct Tensor {
SimpleTensor data;       // actual tensor payload
SimpleTensor amax;       // absolute-max statistic (FP8)
SimpleTensor scale;      // FP8 scale factor
SimpleTensor scale_inv;  // inverse of the scale factor
Tensor() : data(),
amax(nullptr, {1}, DType::kFloat32),
scale(nullptr, {1}, DType::kFloat32),
scale_inv(nullptr, {1}, DType::kFloat32) {}
};
template <typename T>
......@@ -239,62 +253,8 @@ struct TypeInfo{
NVTE_ERROR("Invalid type for 16 bit."); \
}
template<typename T>
struct TypeId{};
template<>
struct TypeId<fp16>{
constexpr static uint32_t Value = 0;
};
template<>
struct TypeId<bf16>{
constexpr static uint32_t Value = 1;
};
template<>
struct TypeId<fp32>{
constexpr static uint32_t Value = 2;
};
template<>
struct TypeId<fp8e4m3>{
constexpr static uint32_t Value = 3;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int S>
struct Type2Key{
constexpr static uint32_t Value = TypeId<T>::Value << S;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct WeightType2Key : public Type2Key<T, 0>{};
template<typename T>
struct InputType2Key : public Type2Key<T, 2>{};
template<typename T>
struct OutputType2Key : public Type2Key<T, 4>{};
template<typename T>
struct ComputeType2Key : public Type2Key<T, 6>{};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename W, typename I, typename O, typename C>
struct Types2Key{
constexpr static uint32_t Value = WeightType2Key<W>::Value | InputType2Key<I>::Value |
OutputType2Key<O>::Value | ComputeType2Key<C>::Value;
constexpr static inline uint64_t get(const uint64_t hidden_size){
constexpr uint64_t type_key = Value;
return (type_key << 32) | hidden_size;
}
};
inline size_t product(const std::vector<size_t> &shape) {
size_t ret = 1;
for (const auto &elem : shape) {
......@@ -320,6 +280,11 @@ struct is_fp8<fp8e5m2> : std::true_type {};
size_t typeToSize(const DType type);
void CheckInputTensor(const Tensor &t, const std::string &name);
void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false);
bool is_fp8_dtype(const DType t);
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_COMMON_H_
......@@ -950,15 +950,15 @@ void scaled_softmax_forward(
float scale_factor,
cudaStream_t stream) {
const int batches = input.shape[0];
const int attn_heads = input.shape[1];
const int query_seq_len = input.shape[2];
const int key_seq_len = input.shape[3];
const int batches = input.data.shape[0];
const int attn_heads = input.data.shape[1];
const int query_seq_len = input.data.shape[2];
const int key_seq_len = input.data.shape[3];
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(input.dtype, softmax_type,
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(input.data.dtype, softmax_type,
dispatch_scaled_softmax_forward<softmax_type, softmax_type, float>(
reinterpret_cast<softmax_type*>(softmax_results->dptr),
reinterpret_cast<const softmax_type*>(input.dptr),
reinterpret_cast<softmax_type*>(softmax_results->data.dptr),
reinterpret_cast<const softmax_type*>(input.data.dptr),
scale_factor,
query_seq_len,
key_seq_len,
......@@ -975,17 +975,17 @@ void scaled_softmax_backward(
cudaStream_t stream) {
// output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = output_grads.shape[0];
const int attn_heads = output_grads.shape[1];
const int query_seq_len = output_grads.shape[2];
const int key_seq_len = output_grads.shape[3];
const int batches = output_grads.data.shape[0];
const int attn_heads = output_grads.data.shape[1];
const int query_seq_len = output_grads.data.shape[2];
const int key_seq_len = output_grads.data.shape[3];
// Softmax Grad
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(output_grads.dtype, softmax_type,
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(output_grads.data.dtype, softmax_type,
dispatch_scaled_masked_softmax_backward<softmax_type, softmax_type, float>(
reinterpret_cast<softmax_type*>(output_grads.dptr),
reinterpret_cast<softmax_type const*>(incoming_grads.dptr),
reinterpret_cast<softmax_type const*>(softmax_results.dptr),
reinterpret_cast<softmax_type*>(output_grads.data.dptr),
reinterpret_cast<softmax_type const*>(incoming_grads.data.dptr),
reinterpret_cast<softmax_type const*>(softmax_results.data.dptr),
scale_factor,
query_seq_len,
key_seq_len,
......@@ -1002,17 +1002,17 @@ void scaled_masked_softmax_forward(
float scale_factor,
cudaStream_t stream) {
const int batches = input.shape[0];
const int pad_batches = mask.shape[0];
const int attn_heads = input.shape[1];
const int query_seq_len = input.shape[2];
const int key_seq_len = input.shape[3];
const int batches = input.data.shape[0];
const int pad_batches = mask.data.shape[0];
const int attn_heads = input.data.shape[1];
const int query_seq_len = input.data.shape[2];
const int key_seq_len = input.data.shape[3];
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(input.dtype, softmax_type,
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(input.data.dtype, softmax_type,
dispatch_scaled_masked_softmax_forward<softmax_type, softmax_type, float>(
reinterpret_cast<softmax_type*>(softmax_results->dptr),
reinterpret_cast<const softmax_type*>(input.dptr),
reinterpret_cast<const uint8_t*>(mask.dptr),
reinterpret_cast<softmax_type*>(softmax_results->data.dptr),
reinterpret_cast<const softmax_type*>(input.data.dptr),
reinterpret_cast<const uint8_t*>(mask.data.dptr),
scale_factor,
query_seq_len,
key_seq_len,
......@@ -1031,17 +1031,17 @@ void scaled_masked_softmax_backward(
cudaStream_t stream
) {
// output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = output_grads.shape[0];
const int attn_heads = output_grads.shape[1];
const int query_seq_len = output_grads.shape[2];
const int key_seq_len = output_grads.shape[3];
const int batches = output_grads.data.shape[0];
const int attn_heads = output_grads.data.shape[1];
const int query_seq_len = output_grads.data.shape[2];
const int key_seq_len = output_grads.data.shape[3];
// Softmax Grad
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(output_grads.dtype, softmax_type,
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(output_grads.data.dtype, softmax_type,
dispatch_scaled_masked_softmax_backward<softmax_type, softmax_type, float>(
reinterpret_cast<softmax_type*>(output_grads.dptr),
reinterpret_cast<softmax_type const*>(incoming_grads.dptr),
reinterpret_cast<softmax_type const*>(softmax_results.dptr),
reinterpret_cast<softmax_type*>(output_grads.data.dptr),
reinterpret_cast<softmax_type const*>(incoming_grads.data.dptr),
reinterpret_cast<softmax_type const*>(softmax_results.data.dptr),
scale_factor,
query_seq_len,
key_seq_len,
......
......@@ -667,13 +667,13 @@ void scaled_upper_triang_masked_softmax_forward(
float scale_factor,
cudaStream_t stream) {
const int attn_batches = input.shape[0];
const int seq_len = input.shape[1];
const int attn_batches = input.data.shape[0];
const int seq_len = input.data.shape[1];
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(input.dtype, softmax_type,
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(input.data.dtype, softmax_type,
dispatch_scaled_upper_triang_masked_softmax_forward<softmax_type, softmax_type, float>(
reinterpret_cast<softmax_type*>(softmax_results->dptr),
reinterpret_cast<const softmax_type*>(input.dptr),
reinterpret_cast<softmax_type*>(softmax_results->data.dptr),
reinterpret_cast<const softmax_type*>(input.data.dptr),
scale_factor,
seq_len,
seq_len,
......@@ -689,15 +689,15 @@ void scaled_upper_triang_masked_softmax_backward(
float scale_factor,
cudaStream_t stream) {
const int attn_batches = output_grads.shape[0];
const int seq_len = output_grads.shape[1];
const int attn_batches = output_grads.data.shape[0];
const int seq_len = output_grads.data.shape[1];
// Softmax Grad
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(output_grads.dtype, softmax_type,
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(output_grads.data.dtype, softmax_type,
dispatch_scaled_upper_triang_masked_softmax_backward<softmax_type, softmax_type, float>(
reinterpret_cast<softmax_type*>(output_grads.dptr),
reinterpret_cast<softmax_type const*>(incoming_grads.dptr),
reinterpret_cast<softmax_type const*>(softmax_results.dptr),
reinterpret_cast<softmax_type*>(output_grads.data.dptr),
reinterpret_cast<softmax_type const*>(incoming_grads.data.dptr),
reinterpret_cast<softmax_type const*>(softmax_results.data.dptr),
scale_factor,
seq_len,
seq_len,
......
......@@ -11,33 +11,67 @@
#include <cublas_v2.h>
#include "../common.h"
namespace {
// Translate a Transformer Engine dtype into the corresponding CUDA library
// data type (cudaDataType_t) for use with cuBLAS/cuBLASLt.
// Dtypes without a CUDA equivalent trigger NVTE_ERROR.
cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) {
  using transformer_engine::DType;
  switch (t) {
    case DType::kFloat32:    return CUDA_R_32F;
    case DType::kFloat16:    return CUDA_R_16F;
    case DType::kBFloat16:   return CUDA_R_16BF;
    case DType::kFloat8E4M3: return CUDA_R_8F_E4M3;
    case DType::kFloat8E5M2: return CUDA_R_8F_E5M2;
    default:
      NVTE_ERROR("Invalid type");
  }
}
} // namespace
namespace transformer_engine {
void cublas_gemm(void* A,
void* A_scale_inverse,
void* B,
void *B_scale_inverse,
void* D,
void* bias_ptr,
void* pre_gelu_out,
void cublas_gemm(const Tensor *inputA,
const Tensor *inputB,
Tensor *outputD,
const Tensor *inputBias,
Tensor *outputPreGelu,
int m, int n, int k,
int lda, int ldb, int ldd,
cudaDataType_t A_type,
cudaDataType_t B_type,
cudaDataType_t D_type,
cudaDataType_t bias_type,
cublasOperation_t transa,
cublasOperation_t transb,
bool bias,
bool gelu,
bool grad,
void* workspace,
size_t workspaceSize,
bool use_fp8,
bool accumulate,
bool use_split_accumulator,
cudaStream_t stream
) {
void *A = inputA->data.dptr;
void *A_scale_inverse = inputA->scale_inv.dptr;
void *B = inputB->data.dptr;
void *B_scale_inverse = inputB->scale_inv.dptr;
void *D = outputD->data.dptr;
void *bias_ptr = inputBias->data.dptr;
const bool bias = bias_ptr != nullptr;
void *pre_gelu_out = outputPreGelu->data.dptr;
const bool gelu = pre_gelu_out != nullptr;
const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
is_fp8_dtype(inputB->data.dtype);
const cudaDataType_t A_type = get_cuda_dtype(inputA->data.dtype);
const cudaDataType_t B_type = get_cuda_dtype(inputB->data.dtype);
const cudaDataType_t D_type = get_cuda_dtype(outputD->data.dtype);
const cudaDataType_t bias_type = get_cuda_dtype(inputBias->data.dtype);
NVTE_CHECK(!is_fp8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr,
"FP8 input to GEMM requires inverse of scale!");
NVTE_CHECK(!is_fp8_dtype(inputB->data.dtype) || B_scale_inverse != nullptr,
"FP8 input to GEMM requires inverse of scale!");
// check consistency of arguments:
// if fp8 is desired, context cannot be null
// fp8 + gelu fusion is unavailable right now.
......@@ -190,37 +224,8 @@ void cublas_gemm(void* A,
} // namespace transformer_engine
namespace {
// Maps a Transformer Engine dtype to the matching cudaDataType_t for the
// CUDA math libraries; any dtype not listed below hits NVTE_ERROR.
cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) {
using namespace transformer_engine;
switch (t) {
case DType::kFloat16:
return CUDA_R_16F;
case DType::kFloat32:
return CUDA_R_32F;
case DType::kBFloat16:
return CUDA_R_16BF;
case DType::kFloat8E4M3:
return CUDA_R_8F_E4M3;
case DType::kFloat8E5M2:
return CUDA_R_8F_E5M2;
default:
NVTE_ERROR("Invalid type");
}
}
// Returns true iff the dtype is one of the FP8 formats (E4M3 or E5M2).
bool is_fp8_dtype(const transformer_engine::DType t) {
  using transformer_engine::DType;
  switch (t) {
    case DType::kFloat8E4M3:
    case DType::kFloat8E5M2:
      return true;
    default:
      return false;
  }
}
} // namespace
void nvte_cublas_gemm(const NVTETensor A,
const NVTETensor A_scale_inverse,
const NVTETensor B,
const NVTETensor B_scale_inverse,
NVTETensor D,
const NVTETensor bias,
NVTETensor pre_gelu_out,
......@@ -234,16 +239,14 @@ void nvte_cublas_gemm(const NVTETensor A,
using namespace transformer_engine;
const Tensor *inputA = reinterpret_cast<const Tensor*>(A);
const Tensor *inputB = reinterpret_cast<const Tensor*>(B);
const Tensor *Ainvscale = reinterpret_cast<const Tensor*>(A_scale_inverse);
const Tensor *Binvscale = reinterpret_cast<const Tensor*>(B_scale_inverse);
Tensor *outputD = reinterpret_cast<Tensor*>(D);
const Tensor *biasTensor = reinterpret_cast<const Tensor*>(bias);
Tensor *outputGelu = reinterpret_cast<Tensor*>(pre_gelu_out);
Tensor *wspace = reinterpret_cast<Tensor*>(workspace);
const int m = transa ? inputA->shape[0] : inputA->shape[1];
const int k = transa ? inputA->shape[1] : inputA->shape[0];
const int n = transb ? inputB->shape[1] : inputB->shape[0];
const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1];
const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0];
const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0];
int lda, ldb, ldd;
if (transa && !transb) { // TN
lda = k;
......@@ -261,23 +264,17 @@ void nvte_cublas_gemm(const NVTETensor A,
NVTE_ERROR("TT layout not allowed.");
}
cublas_gemm(inputA->dptr, Ainvscale->dptr,
inputB->dptr, Binvscale->dptr,
outputD->dptr, biasTensor->dptr,
outputGelu->dptr,
cublas_gemm(inputA,
inputB,
outputD,
biasTensor,
outputGelu,
m, n, k,
lda, ldb, ldd,
get_cuda_dtype(inputA->dtype),
get_cuda_dtype(inputB->dtype),
get_cuda_dtype(outputD->dtype),
get_cuda_dtype(biasTensor->dtype),
(transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
(transb) ? CUBLAS_OP_T : CUBLAS_OP_N,
biasTensor->dptr != nullptr,
outputGelu->dptr != nullptr,
grad, wspace->dptr,
wspace->shape[0],
is_fp8_dtype(inputA->dtype) || is_fp8_dtype(inputB->dtype),
grad, wspace->data.dptr,
wspace->data.shape[0],
accumulate, use_split_accumulator,
stream);
}
......@@ -20,17 +20,11 @@ extern "C" {
/*! \brief Compute GELU activation of the input.
*
* \param[in] input Input tensor for GELU activation.
* \param[out] output Output tensor.
* \param[in] scale Scaling factor of the output tensor.
* \param[in,out] amax AMAX value of the output tensor.
* \param[out] scale_inv Inverse of the output's scaling factor.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_gelu(const NVTETensor input,
NVTETensor output,
const NVTETensor scale,
NVTETensor amax,
NVTETensor scale_inv,
cudaStream_t stream);
#ifdef __cplusplus
......
......@@ -20,28 +20,20 @@ extern "C" {
/*! \brief Cast tensor to FP8.
*
* \param[in] input Input tensor to be cast.
* \param[in] scale Scaling factor of the output tensor.
* \param[out] output Output FP8 tensor.
* \param[in,out] amax AMAX value of the output tensor.
* \param[out] scale_inv Inverse of the output's scaling factor.
* \param[in,out] output Output FP8 tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fp8_quantize(const NVTETensor input,
const NVTETensor scale,
NVTETensor output,
NVTETensor amax,
NVTETensor scale_inv,
cudaStream_t stream);
/*! \brief Cast tensor from FP8.
*
* \param[in] input Input tensor to be cast.
* \param[in] scale_inv Inverse of the input's scaling factor.
* \param[out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fp8_dequantize(const NVTETensor input,
const NVTETensor scale_inv,
NVTETensor output,
cudaStream_t stream);
......
......@@ -25,12 +25,10 @@ extern "C" {
* - `D = GELU(AB + bias)` if both `bias` and `pre_gelu_out` are not empty tensors
*
* \param[in] A The A matrix.
 * \param[in] A_scale_inverse The inverse of the A matrix's scaling factor.
* \param[in] B The B matrix.
 * \param[in] B_scale_inverse The inverse of the B matrix's scaling factor.
* \param[out] D Output matrix.
* \param[in,out] D Output matrix.
* \param[in] bias Bias tensor.
* \param[out] pre_gelu_out Output matrix before GELU activation.
* \param[in,out] pre_gelu_out Output matrix before GELU activation.
* \param[in] transa Whether A matrix is transposed.
* \param[in] transb Whether B matrix is transposed.
* \param[in] grad Whether this operation is part of the
......@@ -41,9 +39,7 @@ extern "C" {
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cublas_gemm(const NVTETensor A,
const NVTETensor A_scale_inverse,
const NVTETensor B,
const NVTETensor B_scale_inverse,
NVTETensor D,
const NVTETensor bias,
NVTETensor pre_gelu_out,
......
......@@ -26,9 +26,8 @@ extern "C" {
* \param[in] x Input tensor of shape [N, H].
* \param[in] gamma Gamma tensor of shape [H].
* \param[in] beta Beta tensor of shape [H].
* \param[in] scale Scaling factor used for output.
* \param[in] epsilon Value added to denominator for numerical stability.
* \param[out] z Output tensor of shape [N, H].
* \param[in,out] z Output tensor of shape [N, H].
* \param[out] mu Mean of the input calculated over the last dimension.
* Shape: [N].
* \param[out] rsigma Inverse of the variance of the input calculated over
......@@ -37,13 +36,10 @@ extern "C" {
* \param[in] multiprocessorCount Number of SMs in the device.
* \param[out] workspace Workspace tensor.
* \param[out] barrier Barrier tensor.
* \param[in,out] amax AMAX value of the output tensor.
* \param[out] scale_inv Inverse of the output's scaling factor.
*/
void nvte_layernorm_fwd(const NVTETensor x,
const NVTETensor gamma,
const NVTETensor beta,
const NVTETensor scale,
const float epsilon,
NVTETensor z,
NVTETensor mu,
......@@ -51,9 +47,7 @@ void nvte_layernorm_fwd(const NVTETensor x,
cudaStream_t stream,
const int multiprocessorCount,
NVTETensor workspace,
NVTETensor barrier,
NVTETensor amax,
NVTETensor scale_inv);
NVTETensor barrier);
/*! \brief Compute backward of LayerNorm.
......
......@@ -59,12 +59,18 @@ typedef void* NVTETensor;
* \param[in] dptr Pointer to the tensor data.
* \param[in] shape Shape of the tensor.
* \param[in] dtype Data type of the tensor.
* \param[in] amax_dptr Pointer to the AMAX value.
* \param[in] scale_dptr Pointer to the scale value.
* \param[in] scale_inv_dptr Pointer to the inverse of scale value.
*
* \return A new TE tensor.
*/
NVTETensor nvte_create_tensor(void *dptr,
const NVTEShape shape,
const NVTEDType dtype);
const NVTEDType dtype,
float *amax_dptr,
float *scale_dptr,
float *scale_inv_dptr);
/*! \brief Destroy a TE tensor.
*
......@@ -99,6 +105,30 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor);
*/
void *nvte_tensor_data(const NVTETensor tensor);
/*! \brief Get a pointer to the tensor's amax data.
*
* \param[in] tensor Tensor.
*
* \return A pointer to tensor's amax data.
*/
float *nvte_tensor_amax(const NVTETensor tensor);
/*! \brief Get a pointer to the tensor's scale data.
*
* \param[in] tensor Tensor.
*
* \return A pointer to tensor's scale data.
*/
float *nvte_tensor_scale(const NVTETensor tensor);
/*! \brief Get a pointer to the tensor's inverse of scale data.
*
* \param[in] tensor Tensor.
*
* \return A pointer to tensor's inverse of scale data.
*/
float *nvte_tensor_scale_inv(const NVTETensor tensor);
#ifdef __cplusplus
} // extern "C"
......@@ -138,9 +168,15 @@ class TensorWrapper {
* \param[in] dptr Pointer to the tensor data.
* \param[in] shape Shape of the tensor.
* \param[in] dtype Data type of the tensor.
* \param[in] amax_dptr Pointer to the AMAX value.
* \param[in] scale_dptr Pointer to the scale value.
* \param[in] scale_inv_dptr Pointer to the inverse of scale value.
*/
TensorWrapper(void *dptr, const NVTEShape &shape, const DType dtype) :
tensor_(nvte_create_tensor(dptr, shape, static_cast<NVTEDType>(dtype))) {}
// Creates the wrapped NVTETensor via nvte_create_tensor. The FP8 metadata
// pointers (amax/scale/scale_inv) default to nullptr — presumably the
// non-FP8 case; confirm against nvte_create_tensor's contract.
TensorWrapper(void *dptr, const NVTEShape &shape, const DType dtype,
float *amax_dptr = nullptr, float *scale_dptr = nullptr,
float *scale_inv_dptr = nullptr) :
tensor_(nvte_create_tensor(dptr, shape, static_cast<NVTEDType>(dtype),
amax_dptr, scale_dptr, scale_inv_dptr)) {}
/*! \brief Constructs new TensorWrapper.
*
......@@ -151,9 +187,15 @@ class TensorWrapper {
* \param[in] dptr Pointer to the tensor data.
* \param[in] shape Shape of the tensor.
* \param[in] dtype Data type of the tensor.
* \param[in] amax_dptr Pointer to the AMAX value.
* \param[in] scale_dptr Pointer to the scale value.
* \param[in] scale_inv_dptr Pointer to the inverse of scale value.
*/
TensorWrapper(void *dptr, const std::vector<size_t> &shape, const DType dtype) :
TensorWrapper(dptr, NVTEShape{shape.data(), shape.size()}, dtype) {}
// Convenience overload taking a std::vector shape: builds an NVTEShape view
// over the vector's storage and delegates to the NVTEShape constructor.
// NOTE(review): the NVTEShape points into `shape`'s buffer — safe only if
// nvte_create_tensor copies the shape; confirm.
TensorWrapper(void *dptr, const std::vector<size_t> &shape, const DType dtype,
float *amax_dptr = nullptr, float *scale_dptr = nullptr,
float *scale_inv_dptr = nullptr) :
TensorWrapper(dptr, NVTEShape{shape.data(), shape.size()}, dtype,
amax_dptr, scale_dptr, scale_inv_dptr) {}
/*! \brief Constructs new empty TensorWrapper.
*
......@@ -229,6 +271,33 @@ class TensorWrapper {
return nvte_tensor_data(tensor_);
}
/*! \brief Get a pointer to the tensor's amax data.
 *
 * \return A pointer to the tensor's amax data, or nullptr if this wrapper
 *         holds no tensor.
 */
float *amax() const noexcept {
  return tensor_ != nullptr ? nvte_tensor_amax(tensor_) : nullptr;
}
/*! \brief Get a pointer to the tensor's scale data.
 *
 * \return A pointer to the tensor's scale data, or nullptr if this wrapper
 *         holds no tensor.
 */
float *scale() const noexcept {
  return tensor_ != nullptr ? nvte_tensor_scale(tensor_) : nullptr;
}
/*! \brief Get a pointer to the tensor's inverse-of-scale data.
 *
 * \return A pointer to the tensor's inverse-of-scale data, or nullptr if
 *         this wrapper holds no tensor.
 */
float *scale_inv() const noexcept {
  return tensor_ != nullptr ? nvte_tensor_scale_inv(tensor_) : nullptr;
}
private:
/*! \brief Wrapped NVTETensor. */
NVTETensor tensor_ = nullptr;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment