Unverified Commit a5ba71f3 authored by Przemyslaw Tredak's avatar Przemyslaw Tredak Committed by GitHub
Browse files

Move the amax/scale/scale_inv into the TE Tensor struct. (#33)



* Move the amax/scale/scale_inv into the TE Tensor struct.
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Handle multi_cast_transpose
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Changed softmax to new Tensor
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* First pass at the cpp tests
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Round of fixes
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Fix multi_cast_transpose
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Fix cast_to_fp8
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarPrzemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>
Signed-off-by: default avatarPrzemyslaw Tredak <ptrendx@gmail.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 509bf877
...@@ -47,29 +47,28 @@ void performTest(const size_t N, const size_t H) { ...@@ -47,29 +47,28 @@ void performTest(const size_t N, const size_t H) {
Tensor input({ N, H }, itype); Tensor input({ N, H }, itype);
Tensor output_c({ N, H }, otype); Tensor output_c({ N, H }, otype);
Tensor output_t({ H, N }, otype); Tensor output_t({ H, N }, otype);
Tensor scale({ 1 }, DType::kFloat32);
Tensor amax({ 1 }, DType::kFloat32);
Tensor scale_inv({ 1 }, DType::kFloat32);
std::unique_ptr<OutputType[]> ref_output_c = std::make_unique<OutputType[]>(N * H); std::unique_ptr<OutputType[]> ref_output_c = std::make_unique<OutputType[]>(N * H);
std::unique_ptr<OutputType[]> ref_output_t = std::make_unique<OutputType[]>(N * H); std::unique_ptr<OutputType[]> ref_output_t = std::make_unique<OutputType[]>(N * H);
fillUniform(input); fillUniform(&input);
fillUniform(scale); setRandomScale(&output_c);
output_t.shareFP8Meta(output_c);
nvte_cast_transpose(input.data(), scale.data(), output_c.data(), output_t.data(), nvte_cast_transpose(input.data(), output_c.data(), output_t.data(), 0);
amax.data(), scale_inv.data(), 0);
float ref_amax; float ref_amax;
compute_ref<InputType, OutputType>(input.cpu_dptr<InputType>(), ref_output_c.get(), compute_ref<InputType, OutputType>(input.cpu_dptr<InputType>(), ref_output_c.get(),
ref_output_t.get(), N, H, &ref_amax, ref_output_t.get(), N, H, &ref_amax,
*(scale.cpu_dptr<float>())); output_c.scale());
cudaDeviceSynchronize(); cudaDeviceSynchronize();
auto err = cudaGetLastError(); auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); if (isFp8Type(otype)) {
compareResults("amax", amax, &ref_amax, atol_amax, rtol_amax); auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype); auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), atol, rtol); compareResults("output_c", output_c, ref_output_c.get(), atol, rtol);
compareResults("output_t", output_t, ref_output_t.get(), atol, rtol); compareResults("output_t", output_t, ref_output_t.get(), atol, rtol);
......
...@@ -23,7 +23,7 @@ namespace { ...@@ -23,7 +23,7 @@ namespace {
template <typename IT, typename OT, typename CT> template <typename IT, typename OT, typename CT>
void compute_ref_cast_transpose_dbias(const IT *input_h, void compute_ref_cast_transpose_dbias(const IT *input_h,
const CT *scale_h, const CT scale,
OT *output_c_h, OT *output_c_h,
OT *output_t_h, OT *output_t_h,
CT *amax_h, CT *amax_h,
...@@ -31,7 +31,6 @@ void compute_ref_cast_transpose_dbias(const IT *input_h, ...@@ -31,7 +31,6 @@ void compute_ref_cast_transpose_dbias(const IT *input_h,
const size_t N, const size_t N,
const size_t H) { const size_t H) {
CT amax = 0.; CT amax = 0.;
CT scale = *scale_h;
std::vector<CT> acc_dbias(H, 0.); std::vector<CT> acc_dbias(H, 0.);
...@@ -67,17 +66,15 @@ void performTest(const size_t N, const size_t H) { ...@@ -67,17 +66,15 @@ void performTest(const size_t N, const size_t H) {
DType ctype = TypeInfo<CType>::dtype; DType ctype = TypeInfo<CType>::dtype;
Tensor input({N, H}, itype); Tensor input({N, H}, itype);
Tensor scale({1}, ctype);
Tensor output_c({N, H}, otype); Tensor output_c({N, H}, otype);
Tensor output_t({ H, N}, otype); Tensor output_t({ H, N}, otype);
Tensor amax({1}, ctype);
Tensor scale_inv({1}, ctype);
// dbias has the same data type with "output grad" // dbias has the same data type with "output grad"
Tensor dbias({H}, itype); Tensor dbias({H}, itype);
fillUniform(input); fillUniform(&input);
fillUniform(scale); setRandomScale(&output_c);
output_t.shareFP8Meta(output_c);
std::unique_ptr<OType[]> ref_output_c = std::make_unique<OType[]>(N*H); std::unique_ptr<OType[]> ref_output_c = std::make_unique<OType[]>(N*H);
std::unique_ptr<OType[]> ref_output_t = std::make_unique<OType[]>(N*H); std::unique_ptr<OType[]> ref_output_t = std::make_unique<OType[]>(N*H);
...@@ -85,7 +82,7 @@ void performTest(const size_t N, const size_t H) { ...@@ -85,7 +82,7 @@ void performTest(const size_t N, const size_t H) {
CType ref_amax; CType ref_amax;
compute_ref_cast_transpose_dbias(input.cpu_dptr<IType>(), compute_ref_cast_transpose_dbias(input.cpu_dptr<IType>(),
scale.cpu_dptr<CType>(), output_c.scale(),
ref_output_c.get(), ref_output_c.get(),
ref_output_t.get(), ref_output_t.get(),
&ref_amax, &ref_amax,
...@@ -95,12 +92,9 @@ void performTest(const size_t N, const size_t H) { ...@@ -95,12 +92,9 @@ void performTest(const size_t N, const size_t H) {
Tensor workspace; Tensor workspace;
nvte_cast_transpose_dbias(input.data(), nvte_cast_transpose_dbias(input.data(),
scale.data(),
output_c.data(), output_c.data(),
output_t.data(), output_t.data(),
amax.data(),
dbias.data(), dbias.data(),
scale_inv.data(),
workspace.data(), workspace.data(),
0); 0);
...@@ -108,12 +102,9 @@ void performTest(const size_t N, const size_t H) { ...@@ -108,12 +102,9 @@ void performTest(const size_t N, const size_t H) {
nvte_cast_transpose_dbias(input.data(), nvte_cast_transpose_dbias(input.data(),
scale.data(),
output_c.data(), output_c.data(),
output_t.data(), output_t.data(),
amax.data(),
dbias.data(), dbias.data(),
scale_inv.data(),
workspace.data(), workspace.data(),
0); 0);
...@@ -121,11 +112,12 @@ void performTest(const size_t N, const size_t H) { ...@@ -121,11 +112,12 @@ void performTest(const size_t N, const size_t H) {
auto err = cudaGetLastError(); auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); if (isFp8Type(otype)) {
compareResults("amax", amax, &ref_amax, atol_amax, rtol_amax); auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
float ref_scale_inv = 1.f / (*scale.cpu_dptr<float>()); compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
compareResults("scale_inv", scale_inv, &ref_scale_inv, atol_amax, rtol_amax); float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype); auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), atol, rtol); compareResults("output_c", output_c, ref_output_c.get(), atol, rtol);
compareResults("output_t", output_t, ref_output_t.get(), atol, rtol); compareResults("output_t", output_t, ref_output_t.get(), atol, rtol);
......
...@@ -32,7 +32,7 @@ CType dgelu(const CType cval) { ...@@ -32,7 +32,7 @@ CType dgelu(const CType cval) {
template <typename IT, typename OT, typename CT> template <typename IT, typename OT, typename CT>
void compute_ref_cast_transpose_dbias_dgelu(const IT *input, void compute_ref_cast_transpose_dbias_dgelu(const IT *input,
const IT *gelu_input, const IT *gelu_input,
const CT *scale_h, const CT scale,
OT *output_c, OT *output_c,
OT *output_t, OT *output_t,
CT *amax_h, CT *amax_h,
...@@ -40,7 +40,6 @@ void compute_ref_cast_transpose_dbias_dgelu(const IT *input, ...@@ -40,7 +40,6 @@ void compute_ref_cast_transpose_dbias_dgelu(const IT *input,
const size_t N, const size_t N,
const size_t H) { const size_t H) {
CT amax = 0.; CT amax = 0.;
CT scale = *scale_h;
std::vector<CT> acc_dbias(H, 0.); std::vector<CT> acc_dbias(H, 0.);
...@@ -79,18 +78,16 @@ void performTest(const size_t N, const size_t H) { ...@@ -79,18 +78,16 @@ void performTest(const size_t N, const size_t H) {
Tensor input({N, H}, itype); Tensor input({N, H}, itype);
Tensor gelu_input({N, H}, itype); Tensor gelu_input({N, H}, itype);
Tensor scale({1}, ctype);
Tensor output_c({N, H}, otype); Tensor output_c({N, H}, otype);
Tensor output_t({ H, N}, otype); Tensor output_t({ H, N}, otype);
Tensor amax({1}, ctype);
Tensor scale_inv({1}, ctype);
// dbias has the same data type with "output grad" // dbias has the same data type with "output grad"
Tensor dbias({H}, itype); Tensor dbias({H}, itype);
fillUniform(input); fillUniform(&input);
fillUniform(gelu_input); fillUniform(&gelu_input);
fillUniform(scale); setRandomScale(&output_c);
output_t.shareFP8Meta(output_c);
std::unique_ptr<OType[]> ref_output_c = std::make_unique<OType[]>(N*H); std::unique_ptr<OType[]> ref_output_c = std::make_unique<OType[]>(N*H);
std::unique_ptr<OType[]> ref_output_t = std::make_unique<OType[]>(N*H); std::unique_ptr<OType[]> ref_output_t = std::make_unique<OType[]>(N*H);
...@@ -99,7 +96,7 @@ void performTest(const size_t N, const size_t H) { ...@@ -99,7 +96,7 @@ void performTest(const size_t N, const size_t H) {
CType ref_amax; CType ref_amax;
compute_ref_cast_transpose_dbias_dgelu(input.cpu_dptr<IType>(), compute_ref_cast_transpose_dbias_dgelu(input.cpu_dptr<IType>(),
gelu_input.cpu_dptr<IType>(), gelu_input.cpu_dptr<IType>(),
scale.cpu_dptr<CType>(), output_c.scale(),
ref_output_c.get(), ref_output_c.get(),
ref_output_t.get(), ref_output_t.get(),
&ref_amax, &ref_amax,
...@@ -110,12 +107,9 @@ void performTest(const size_t N, const size_t H) { ...@@ -110,12 +107,9 @@ void performTest(const size_t N, const size_t H) {
nvte_cast_transpose_dbias_dgelu(input.data(), nvte_cast_transpose_dbias_dgelu(input.data(),
gelu_input.data(), gelu_input.data(),
scale.data(),
output_c.data(), output_c.data(),
output_t.data(), output_t.data(),
amax.data(),
dbias.data(), dbias.data(),
scale_inv.data(),
workspace.data(), workspace.data(),
0); 0);
...@@ -124,12 +118,9 @@ void performTest(const size_t N, const size_t H) { ...@@ -124,12 +118,9 @@ void performTest(const size_t N, const size_t H) {
nvte_cast_transpose_dbias_dgelu(input.data(), nvte_cast_transpose_dbias_dgelu(input.data(),
gelu_input.data(), gelu_input.data(),
scale.data(),
output_c.data(), output_c.data(),
output_t.data(), output_t.data(),
amax.data(),
dbias.data(), dbias.data(),
scale_inv.data(),
workspace.data(), workspace.data(),
0); 0);
...@@ -137,10 +128,12 @@ void performTest(const size_t N, const size_t H) { ...@@ -137,10 +128,12 @@ void performTest(const size_t N, const size_t H) {
auto err = cudaGetLastError(); auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); if (isFp8Type(otype)) {
compareResults("amax", amax, &ref_amax, atol_amax, rtol_amax); auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
float ref_scale_inv = 1.f / (*scale.cpu_dptr<float>()); compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
compareResults("scale_inv", scale_inv, &ref_scale_inv, atol_amax, rtol_amax); float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype); auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), atol, rtol); compareResults("output_c", output_c, ref_output_c.get(), atol, rtol);
......
...@@ -23,16 +23,11 @@ using namespace transformer_engine; ...@@ -23,16 +23,11 @@ using namespace transformer_engine;
template <typename IT, typename OT, typename CT> template <typename IT, typename OT, typename CT>
void compute_ref_gelu_cast(const IT *input_h, void compute_ref_gelu_cast(const IT *input_h,
OT *output_h, OT *output_h,
const CT *scale_h, const CT scale,
CT *amax_h, CT *amax_h,
const size_t N, const size_t N,
const size_t H) { const size_t H) {
CT amax = 0.; CT amax = 0.;
CT scale = 1;
if (std::is_same<OT, test::fp8e4m3>::value ||
std::is_same<OT, test::fp8e5m2>::value) {
scale = *scale_h;
}
for (size_t i = 0; i < N; i++) { for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < H; j++) { for (size_t j = 0; j < H; j++) {
...@@ -51,30 +46,22 @@ template <typename IType, typename OType> ...@@ -51,30 +46,22 @@ template <typename IType, typename OType>
void performTestGelu(const size_t N, const size_t H) { void performTestGelu(const size_t N, const size_t H) {
using namespace test; using namespace test;
using CType = fp32;
DType itype = TypeInfo<IType>::dtype; DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype; DType otype = TypeInfo<OType>::dtype;
DType ctype = TypeInfo<CType>::dtype;
Tensor input({ N, H }, itype); Tensor input({ N, H }, itype);
Tensor output({ N, H }, otype); Tensor output({ N, H }, otype);
Tensor scale({ 1 }, ctype);
Tensor amax({ 1 }, ctype);
Tensor scale_inv({ 1 }, ctype);
fillUniform(input); fillUniform(&input);
fillUniform(scale); setRandomScale(&output);
std::unique_ptr<OType[]> ref_output = std::make_unique<OType[]>(N*H); std::unique_ptr<OType[]> ref_output = std::make_unique<OType[]>(N*H);
nvte_gelu(input.data(), output.data(), scale.data(), nvte_gelu(input.data(), output.data(), 0);
amax.data(), scale_inv.data(), 0);
float ref_amax; float ref_amax;
compute_ref_gelu_cast(input.cpu_dptr<IType>(), ref_output.get(), compute_ref_gelu_cast(input.cpu_dptr<IType>(), ref_output.get(),
scale.cpu_dptr<float>(), output.scale(), &ref_amax, N, H);
&ref_amax, N, H);
cudaDeviceSynchronize(); cudaDeviceSynchronize();
auto err = cudaGetLastError(); auto err = cudaGetLastError();
...@@ -82,9 +69,9 @@ void performTestGelu(const size_t N, const size_t H) { ...@@ -82,9 +69,9 @@ void performTestGelu(const size_t N, const size_t H) {
if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) { if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", amax, &ref_amax, atol_amax, rtol_amax); compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / (*scale.cpu_dptr<float>()); float ref_scale_inv = 1.f / output.scale();
compareResults("scale_inv", scale_inv, &ref_scale_inv, atol_amax, rtol_amax); compareResults("scale_inv", output.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
} }
auto [atol, rtol] = getTolerances(otype); auto [atol, rtol] = getTolerances(otype);
compareResults("output_gelu", output, ref_output.get(), atol, rtol); compareResults("output_gelu", output, ref_output.get(), atol, rtol);
......
...@@ -132,9 +132,6 @@ void performTest(const size_t N, const size_t H) { ...@@ -132,9 +132,6 @@ void performTest(const size_t N, const size_t H) {
Tensor z({ N, H }, otype); Tensor z({ N, H }, otype);
Tensor gamma({ H }, wtype); Tensor gamma({ H }, wtype);
Tensor beta({ H }, wtype); Tensor beta({ H }, wtype);
Tensor scale({ 1 }, DType::kFloat32);
Tensor amax({ 1 }, DType::kFloat32);
Tensor scale_inv({ 1 }, DType::kFloat32);
Tensor mu({ N }, DType::kFloat32); Tensor mu({ N }, DType::kFloat32);
Tensor rsigma({ N }, DType::kFloat32); Tensor rsigma({ N }, DType::kFloat32);
Tensor dz({ N, H }, wtype); Tensor dz({ N, H }, wtype);
...@@ -143,11 +140,11 @@ void performTest(const size_t N, const size_t H) { ...@@ -143,11 +140,11 @@ void performTest(const size_t N, const size_t H) {
Tensor dbeta({ H }, wtype); Tensor dbeta({ H }, wtype);
Tensor workspace, barrier, dgamma_part, dbeta_part; Tensor workspace, barrier, dgamma_part, dbeta_part;
fillUniform(input); fillUniform(&input);
fillUniform(gamma); fillUniform(&gamma);
fillUniform(beta); fillUniform(&beta);
fillUniform(scale); setRandomScale(&z);
fillUniform(dz); fillUniform(&dz);
std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N * H); std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N * H);
std::unique_ptr<float[]> ref_mu = std::make_unique<float[]>(N); std::unique_ptr<float[]> ref_mu = std::make_unique<float[]>(N);
...@@ -161,14 +158,14 @@ void performTest(const size_t N, const size_t H) { ...@@ -161,14 +158,14 @@ void performTest(const size_t N, const size_t H) {
// Forward kernel // Forward kernel
float epsilon = 1e-5; float epsilon = 1e-5;
nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), scale.data(), epsilon, nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon,
z.data(), mu.data(), rsigma.data(), 0, prop.multiProcessorCount, z.data(), mu.data(), rsigma.data(), 0, prop.multiProcessorCount,
workspace.data(), barrier.data(), amax.data(), scale_inv.data()); workspace.data(), barrier.data());
workspace = Tensor(workspace.shape(), workspace.dtype()); workspace = Tensor(workspace.shape(), workspace.dtype());
barrier = Tensor(barrier.shape(), barrier.dtype()); barrier = Tensor(barrier.shape(), barrier.dtype());
nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), scale.data(), epsilon, nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon,
z.data(), mu.data(), rsigma.data(), 0, prop.multiProcessorCount, z.data(), mu.data(), rsigma.data(), 0, prop.multiProcessorCount,
workspace.data(), barrier.data(), amax.data(), scale_inv.data()); workspace.data(), barrier.data());
// Backward kernel // Backward kernel
nvte_layernorm_bwd(dz.data(), input.data(), nvte_layernorm_bwd(dz.data(), input.data(),
...@@ -195,7 +192,7 @@ void performTest(const size_t N, const size_t H) { ...@@ -195,7 +192,7 @@ void performTest(const size_t N, const size_t H) {
float ref_amax; float ref_amax;
compute_ref_stats(input.cpu_dptr<InputType>(), ref_mu.get(), compute_ref_stats(input.cpu_dptr<InputType>(), ref_mu.get(),
ref_rsigma.get(), N, H, epsilon); ref_rsigma.get(), N, H, epsilon);
float ref_scale = isFp8Type(otype) ? *(scale.cpu_dptr<float>()) : 1.f; float ref_scale = isFp8Type(otype) ? z.scale() : 1.f;
compute_ref_output(input.cpu_dptr<InputType>(), compute_ref_output(input.cpu_dptr<InputType>(),
gamma.cpu_dptr<WeightType>(), gamma.cpu_dptr<WeightType>(),
beta.cpu_dptr<WeightType>(), beta.cpu_dptr<WeightType>(),
...@@ -217,9 +214,9 @@ void performTest(const size_t N, const size_t H) { ...@@ -217,9 +214,9 @@ void performTest(const size_t N, const size_t H) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
if (isFp8Type(otype)) { if (isFp8Type(otype)) {
compareResults("amax", amax, &ref_amax, atol_amax, rtol_amax); compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / (*scale.cpu_dptr<float>()); float ref_scale_inv = 1.f / z.scale();
compareResults("scale_inv", scale_inv, &ref_scale_inv, atol_amax, rtol_amax); compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
} }
auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32); auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32);
......
...@@ -60,7 +60,6 @@ void performTest() { ...@@ -60,7 +60,6 @@ void performTest() {
const DType itype = TypeInfo<InputType>::dtype; const DType itype = TypeInfo<InputType>::dtype;
const DType otype = TypeInfo<OutputType>::dtype; const DType otype = TypeInfo<OutputType>::dtype;
const DType ctype = DType::kFloat32;
const std::vector<std::pair<size_t, size_t>> tensor_dims = {{1,1}, const std::vector<std::pair<size_t, size_t>> tensor_dims = {{1,1},
{1,768}, {1,768},
{768,1}, {768,1},
...@@ -72,8 +71,7 @@ void performTest() { ...@@ -72,8 +71,7 @@ void performTest() {
const size_t num_tensors = tensor_dims.size(); const size_t num_tensors = tensor_dims.size();
// Buffers for Transformer Engine implementation // Buffers for Transformer Engine implementation
std::vector<Tensor> input_list, output_c_list, output_t_list, std::vector<Tensor> input_list, output_c_list, output_t_list;
scale_list, amax_list, scale_inv_list;
// Buffers for reference implementation // Buffers for reference implementation
std::vector<std::vector<InputType>> ref_input_list; std::vector<std::vector<InputType>> ref_input_list;
...@@ -89,16 +87,13 @@ void performTest() { ...@@ -89,16 +87,13 @@ void performTest() {
input_list.emplace_back(Tensor({ height, width }, itype)); input_list.emplace_back(Tensor({ height, width }, itype));
output_c_list.emplace_back(Tensor({ height, width }, otype)); output_c_list.emplace_back(Tensor({ height, width }, otype));
output_t_list.emplace_back(Tensor({ width, height }, otype)); output_t_list.emplace_back(Tensor({ width, height }, otype));
scale_list.emplace_back(Tensor({ 1 }, ctype));
amax_list.emplace_back(Tensor({ 1 }, ctype));
scale_inv_list.emplace_back(Tensor({ 1 }, ctype));
auto& input = input_list.back(); auto& input = input_list.back();
auto& scale = scale_list.back(); auto& output_c = output_c_list.back();
fillUniform(input); auto& output_t = output_t_list.back();
fillUniform(scale); fillUniform(&input);
*scale.cpu_dptr<float>() += 2.5; setRandomScale(&output_c);
scale.from_cpu(); output_t.shareFP8Meta(output_c);
ref_input_list.emplace_back(height*width); ref_input_list.emplace_back(height*width);
ref_output_c_list.emplace_back(height*width); ref_output_c_list.emplace_back(height*width);
...@@ -107,7 +102,7 @@ void performTest() { ...@@ -107,7 +102,7 @@ void performTest() {
std::copy(input.cpu_dptr<InputType>(), std::copy(input.cpu_dptr<InputType>(),
input.cpu_dptr<InputType>() + height * width, input.cpu_dptr<InputType>() + height * width,
ref_input_list.back().begin()); ref_input_list.back().begin());
ref_scale_list[tensor_id] = *scale.cpu_dptr<float>(); ref_scale_list[tensor_id] = output_c.scale();
ref_height_list[tensor_id] = height; ref_height_list[tensor_id] = height;
ref_width_list[tensor_id] = width; ref_width_list[tensor_id] = width;
} }
...@@ -123,11 +118,8 @@ void performTest() { ...@@ -123,11 +118,8 @@ void performTest() {
}; };
nvte_multi_cast_transpose(num_tensors, nvte_multi_cast_transpose(num_tensors,
make_nvte_vector(input_list).data(), make_nvte_vector(input_list).data(),
make_nvte_vector(scale_list).data(),
make_nvte_vector(output_c_list).data(), make_nvte_vector(output_c_list).data(),
make_nvte_vector(output_t_list).data(), make_nvte_vector(output_t_list).data(),
make_nvte_vector(amax_list).data(),
make_nvte_vector(scale_inv_list).data(),
0); 0);
// Reference implementation // Reference implementation
...@@ -145,15 +137,17 @@ void performTest() { ...@@ -145,15 +137,17 @@ void performTest() {
auto err = cudaGetLastError(); auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) { for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); if (isFp8Type(otype)) {
compareResults("amax", auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
amax_list[tensor_id], compareResults("amax",
&ref_amax_list[tensor_id], output_c_list[tensor_id].amax(),
atol_amax, rtol_amax); ref_amax_list[tensor_id],
compareResults("scale_inv", atol_amax, rtol_amax);
scale_inv_list[tensor_id], compareResults("scale_inv",
&ref_scale_inv_list[tensor_id], output_c_list[tensor_id].scale_inv(),
atol_amax, rtol_amax); ref_scale_inv_list[tensor_id],
atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype); auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", compareResults("output_c",
output_c_list[tensor_id], output_c_list[tensor_id],
......
...@@ -60,29 +60,26 @@ void performTestQ(const size_t N) { ...@@ -60,29 +60,26 @@ void performTestQ(const size_t N) {
Tensor input({ N }, itype); Tensor input({ N }, itype);
Tensor output({ N }, otype); Tensor output({ N }, otype);
Tensor scale({ 1 }, DType::kFloat32);
Tensor amax({ 1 }, DType::kFloat32);
Tensor scale_inv({ 1 }, DType::kFloat32);
std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N); std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N);
fillUniform(input); fillUniform(&input);
fillUniform(scale); setRandomScale(&output);
nvte_fp8_quantize(input.data(), scale.data(), output.data(), amax.data(), scale_inv.data(), 0); nvte_fp8_quantize(input.data(), output.data(), 0);
float ref_amax; float ref_amax;
compute_ref_q<InputType, OutputType>(input.cpu_dptr<InputType>(), ref_output.get(), compute_ref_q<InputType, OutputType>(input.cpu_dptr<InputType>(), ref_output.get(),
N, &ref_amax, *(scale.cpu_dptr<float>())); N, &ref_amax, output.scale());
cudaDeviceSynchronize(); cudaDeviceSynchronize();
auto err = cudaGetLastError(); auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", amax, &ref_amax, atol_amax, rtol_amax); compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / (*scale.cpu_dptr<float>()); float ref_scale_inv = 1.f / output.scale();
compareResults("scale_inv", scale_inv, &ref_scale_inv, atol_amax, rtol_amax); compareResults("scale_inv", output.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
auto [atol, rtol] = getTolerances(otype); auto [atol, rtol] = getTolerances(otype);
compareResults("output_q", output, ref_output.get(), atol, rtol); compareResults("output_q", output, ref_output.get(), atol, rtol);
} }
...@@ -96,17 +93,15 @@ void performTestDQ(const size_t N) { ...@@ -96,17 +93,15 @@ void performTestDQ(const size_t N) {
Tensor input({ N }, itype); Tensor input({ N }, itype);
Tensor output({ N }, otype); Tensor output({ N }, otype);
Tensor scale_inv({ 1 }, DType::kFloat32);
std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N); std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N);
fillUniform(input); fillUniform(&input);
fillUniform(scale_inv);
nvte_fp8_dequantize(input.data(), scale_inv.data(), output.data(), 0); nvte_fp8_dequantize(input.data(), output.data(), 0);
compute_ref_dq<InputType, OutputType>(input.cpu_dptr<InputType>(), ref_output.get(), compute_ref_dq<InputType, OutputType>(input.cpu_dptr<InputType>(), ref_output.get(),
N, *(scale_inv.cpu_dptr<float>())); N, input.scale_inv());
cudaDeviceSynchronize(); cudaDeviceSynchronize();
auto err = cudaGetLastError(); auto err = cudaGetLastError();
......
...@@ -41,7 +41,7 @@ void performTest(const size_t N, const size_t H) { ...@@ -41,7 +41,7 @@ void performTest(const size_t N, const size_t H) {
std::unique_ptr<Type[]> ref_output = std::make_unique<Type[]>(N * H); std::unique_ptr<Type[]> ref_output = std::make_unique<Type[]>(N * H);
fillUniform(input); fillUniform(&input);
nvte_transpose(input.data(), output.data(), 0); nvte_transpose(input.data(), output.data(), 0);
......
...@@ -63,24 +63,87 @@ Tensor::Tensor(const NVTEShape &shape, const DType type) { ...@@ -63,24 +63,87 @@ Tensor::Tensor(const NVTEShape &shape, const DType type) {
size_t total_size = product(shape) * s; size_t total_size = product(shape) * s;
void *dptr = nullptr; void *dptr = nullptr;
cpu_data_ = nullptr; cpu_data_ = nullptr;
amax_cpu_data_ = nullptr;
scale_cpu_data_ = nullptr;
scale_inv_cpu_data_ = nullptr;
float *amax = nullptr, *scale = nullptr, *scale_inv = nullptr;
if (total_size != 0) { if (total_size != 0) {
cudaMalloc((void**)&dptr, total_size); // NOLINT(*) cudaMalloc((void**)&dptr, total_size); // NOLINT(*)
cudaMemset(dptr, 0, total_size); cudaMemset(dptr, 0, total_size);
cpu_data_ = std::make_unique<unsigned char[]>(total_size); cpu_data_ = std::make_unique<unsigned char[]>(total_size);
for (size_t i = 0; i < total_size; ++i) {
cpu_data_[i] = 0;
}
} }
tensor_ = TensorWrapper(dptr, shape, type); if (isFp8Type(type)) {
cudaMalloc((void**)&amax, sizeof(float)); // NOLINT(*)
cudaMemset(amax, 0, sizeof(float));
cudaMalloc((void**)&scale, sizeof(float)); // NOLINT(*)
cudaMemset(scale, 0, sizeof(float));
cudaMalloc((void**)&scale_inv, sizeof(float)); // NOLINT(*)
cudaMemset(scale_inv, 0, sizeof(float));
amax_cpu_data_ = std::make_shared<float>();
*amax_cpu_data_ = 0;
scale_cpu_data_ = std::make_shared<float>();
*scale_cpu_data_ = 0;
scale_inv_cpu_data_ = std::make_shared<float>();
*scale_inv_cpu_data_ = 0;
}
tensor_ = TensorWrapper(dptr, shape, type, amax, scale, scale_inv);
} }
void Tensor::to_cpu() const { void Tensor::to_cpu() const {
const NVTEShape s = tensor_.shape(); const NVTEShape s = tensor_.shape();
const size_t size = product(s) * typeToSize(tensor_.dtype()); const size_t size = product(s) * typeToSize(tensor_.dtype());
cudaMemcpy(cpu_data_.get(), tensor_.dptr(), size, cudaMemcpyDeviceToHost); cudaMemcpy(cpu_data_.get(), tensor_.dptr(), size, cudaMemcpyDeviceToHost);
if (isFp8Type(dtype())) {
cudaMemcpy(amax_cpu_data_.get(), tensor_.amax(), sizeof(float),
cudaMemcpyDeviceToHost);
cudaMemcpy(scale_cpu_data_.get(), tensor_.scale(), sizeof(float),
cudaMemcpyDeviceToHost);
cudaMemcpy(scale_inv_cpu_data_.get(), tensor_.scale_inv(), sizeof(float),
cudaMemcpyDeviceToHost);
}
} }
void Tensor::from_cpu() const { void Tensor::from_cpu() const {
const NVTEShape s = tensor_.shape(); const NVTEShape s = tensor_.shape();
const size_t size = product(s) * typeToSize(tensor_.dtype()); const size_t size = product(s) * typeToSize(tensor_.dtype());
cudaMemcpy(tensor_.dptr(), cpu_data_.get(), size, cudaMemcpyHostToDevice); cudaMemcpy(tensor_.dptr(), cpu_data_.get(), size, cudaMemcpyHostToDevice);
if (isFp8Type(dtype())) {
cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float),
cudaMemcpyHostToDevice);
cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float),
cudaMemcpyHostToDevice);
cudaMemcpy(tensor_.scale_inv(), scale_inv_cpu_data_.get(), sizeof(float),
cudaMemcpyHostToDevice);
}
}
void Tensor::set_scale(float scale) {
if (isFp8Type(dtype())) {
NVTE_CHECK(scale_cpu_data_);
*scale_cpu_data_ = scale;
from_cpu();
}
}
void Tensor::set_scale_inv(float scale_inv) {
if (isFp8Type(dtype())) {
NVTE_CHECK(scale_inv_cpu_data_);
*scale_inv_cpu_data_ = scale_inv;
from_cpu();
}
}
void Tensor::shareFP8Meta(const Tensor &other) {
if(isFp8Type(dtype()) && isFp8Type(other.dtype())) {
tensor_ = TensorWrapper(dptr(), shape(), dtype(),
other.tensor_.amax(),
other.tensor_.scale(),
other.tensor_.scale_inv());
to_cpu();
}
} }
using std::to_string; using std::to_string;
...@@ -141,6 +204,16 @@ void compareResults(const std::string &name, const Tensor &test, const void *ref ...@@ -141,6 +204,16 @@ void compareResults(const std::string &name, const Tensor &test, const void *ref
); );
} }
void compareResults(const std::string &name, const float test, const float ref,
double atol, double rtol) {
double t = static_cast<double>(test);
double r = static_cast<double>(ref);
bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
ASSERT_FALSE(mismatch) << "Error in " << name << std::endl
<< "Mismatch: " << t << " vs " << r;
}
std::pair<double, double> getTolerances(const DType type) { std::pair<double, double> getTolerances(const DType type) {
switch(type) { switch(type) {
case DType::kFloat32: case DType::kFloat32:
...@@ -158,17 +231,25 @@ std::pair<double, double> getTolerances(const DType type) { ...@@ -158,17 +231,25 @@ std::pair<double, double> getTolerances(const DType type) {
return {0, 0}; return {0, 0};
} }
void fillUniform(const Tensor &t) { void fillUniform(Tensor *t) {
const size_t size = product(t.shape()); const size_t size = product(t->shape());
static std::mt19937 gen(12345); static std::mt19937 gen(12345);
std::uniform_real_distribution<> dis(-2.0, 1.0); std::uniform_real_distribution<> dis(-2.0, 1.0);
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t.dtype(), T, { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, {
T *data = t.cpu_dptr<T>(); T *data = t->cpu_dptr<T>();
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
data[i] = T(dis(gen)); data[i] = T(dis(gen));
} }
}); });
t.from_cpu(); t->set_scale_inv(dis(gen));
t->from_cpu();
}
// Draws one value from a fixed-seed RNG and installs it as the tensor's
// FP8 scale through Tensor::set_scale().
void setRandomScale(Tensor *t) {
  static std::mt19937 gen(12345);
  std::uniform_real_distribution<> dis(-2.0, 1.0);
  t->set_scale(static_cast<float>(dis(gen)));
}
bool isFp8Type(DType type) { bool isFp8Type(DType type) {
......
...@@ -130,12 +130,45 @@ class Tensor { ...@@ -130,12 +130,45 @@ class Tensor {
return reinterpret_cast<T *>(cpu_data_.get()); return reinterpret_cast<T *>(cpu_data_.get());
} }
// Host-side amax value; refreshes the CPU copy from the device first.
// Returns 0 when this tensor carries no amax buffer.
float amax() const {
  if (amax_cpu_data_ == nullptr) {
    return 0;
  }
  to_cpu();
  return *amax_cpu_data_;
}
// Host-side scale value; refreshes the CPU copy from the device first.
// Returns 1 when this tensor carries no scale buffer.
float scale() const {
  if (scale_cpu_data_ == nullptr) {
    return 1;
  }
  to_cpu();
  return *scale_cpu_data_;
}
// Host-side inverse-scale value; refreshes the CPU copy from the device
// first. Returns 1 when this tensor carries no scale_inv buffer.
float scale_inv() const {
  if (scale_inv_cpu_data_ == nullptr) {
    return 1;
  }
  to_cpu();
  return *scale_inv_cpu_data_;
}
void to_cpu() const; void to_cpu() const;
void from_cpu() const; void from_cpu() const;
void set_scale(float scale);
void set_scale_inv(float scale_inv);
void shareFP8Meta(const Tensor &other);
private: private:
TensorWrapper tensor_; TensorWrapper tensor_;
std::unique_ptr<unsigned char[]> cpu_data_; std::unique_ptr<unsigned char[]> cpu_data_;
std::shared_ptr<float> amax_cpu_data_;
std::shared_ptr<float> scale_cpu_data_;
std::shared_ptr<float> scale_inv_cpu_data_;
}; };
size_t typeToSize(DType type); size_t typeToSize(DType type);
...@@ -145,10 +178,13 @@ bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2); ...@@ -145,10 +178,13 @@ bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2);
void compareResults(const std::string &name, const Tensor &test, const void *ref, void compareResults(const std::string &name, const Tensor &test, const void *ref,
double atol = 1e-5, double rtol = 1e-8); double atol = 1e-5, double rtol = 1e-8);
void compareResults(const std::string &name, const float test, const float ref,
double atol = 1e-5, double rtol = 1e-8);
std::pair<double, double> getTolerances(const DType type); std::pair<double, double> getTolerances(const DType type);
void fillUniform(const Tensor &t); void fillUniform(Tensor *t);
void setRandomScale(Tensor *t);
constexpr int THREADS_PER_WARP = 32; constexpr int THREADS_PER_WARP = 32;
......
...@@ -26,39 +26,24 @@ __device__ inline fp32 gelu(fp32 value, const GELUParam &) { ...@@ -26,39 +26,24 @@ __device__ inline fp32 gelu(fp32 value, const GELUParam &) {
} }
void gelu_cast(const Tensor &input, void gelu_cast(const Tensor &input,
const Tensor &scale,
Tensor *output, Tensor *output,
Tensor *amax,
Tensor *scale_inv,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_CHECK(input.shape.size() == 2, "Input must have 2 dimensions."); CheckInputTensor(input, "gelu_input");
NVTE_CHECK(output->shape.size() == 2, "Output must have 2 dimensions."); CheckOutputTensor(*output, "gelu_output");
NVTE_CHECK(input.shape == output->shape, "Input and output shapes must match."); NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
const size_t tot_elts = input.shape[1] * input.shape[0]; NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
NVTE_CHECK(amax->shape == std::vector<size_t>{ 1 }, "AMAX tensor must have 1 element."); const size_t tot_elts = input.data.shape[1] * input.data.shape[0];
NVTE_CHECK(amax->dtype == DType::kFloat32, "AMAX tensor must have Float32 type.");
NVTE_CHECK(scale.shape == std::vector<size_t>{ 1 }, "Scale tensor must have 1 element."); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
NVTE_CHECK(scale.dtype == DType::kFloat32, "Scale tensor must have Float32 type."); TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
NVTE_CHECK(scale_inv->shape == std::vector<size_t>{ 1 },
"scale_inv tensor must have 1 element.");
NVTE_CHECK(scale_inv->dtype == DType::kFloat32, "scale_inv tensor must have Float32 type.");
NVTE_CHECK(input.dptr != nullptr, "Input is not allocated.");
NVTE_CHECK(scale.dptr != nullptr, "Scale is not allocated.");
NVTE_CHECK(output->dptr != nullptr, "Output is not allocated.");
NVTE_CHECK(amax->dptr != nullptr, "AMAX tensor is not allocated.");
NVTE_CHECK(scale_inv->dptr != nullptr, "scale_inv tensor is not allocated.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->dtype, OType,
constexpr int nvec = 32 / sizeof(IType); constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryKernelLauncher<nvec, detail::GELUParam, detail::gelu>( VectorizedUnaryKernelLauncher<nvec, detail::GELUParam, detail::gelu>(
reinterpret_cast<const IType*>(input.dptr), reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->dptr), reinterpret_cast<OType*>(output->data.dptr),
reinterpret_cast<const fp32*>(scale.dptr), reinterpret_cast<const fp32*>(output->scale.dptr),
reinterpret_cast<fp32*>(scale_inv->dptr), reinterpret_cast<fp32*>(output->scale_inv.dptr),
reinterpret_cast<fp32*>(amax->dptr), reinterpret_cast<fp32*>(output->amax.dptr),
tot_elts, tot_elts,
{}, {},
stream); stream);
...@@ -70,15 +55,9 @@ void gelu_cast(const Tensor &input, ...@@ -70,15 +55,9 @@ void gelu_cast(const Tensor &input,
void nvte_gelu(const NVTETensor input, void nvte_gelu(const NVTETensor input,
NVTETensor output, NVTETensor output,
const NVTETensor scale,
NVTETensor amax,
NVTETensor scale_inv,
cudaStream_t stream) { cudaStream_t stream) {
using namespace transformer_engine; using namespace transformer_engine;
gelu_cast(*reinterpret_cast<const Tensor*>(input), gelu_cast(*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(scale),
reinterpret_cast<Tensor*>(output), reinterpret_cast<Tensor*>(output),
reinterpret_cast<Tensor*>(amax),
reinterpret_cast<Tensor*>(scale_inv),
stream); stream);
} }
...@@ -23,12 +23,26 @@ ...@@ -23,12 +23,26 @@
namespace transformer_engine { namespace transformer_engine {
struct Tensor { struct SimpleTensor {
void* dptr; void *dptr;
std::vector<size_t> shape; std::vector<size_t> shape;
DType dtype; DType dtype;
SimpleTensor(void *dptr, const std::vector<size_t> &shape, DType dtype) :
dptr(dptr), shape(shape), dtype(dtype) {}
SimpleTensor() : SimpleTensor(nullptr, {}, DType::kFloat32) {}
};
Tensor() : dptr(nullptr), shape(), dtype(DType::kFloat32) {} struct Tensor {
SimpleTensor data;
SimpleTensor amax;
SimpleTensor scale;
SimpleTensor scale_inv;
Tensor() : data(),
amax(nullptr, {1}, DType::kFloat32),
scale(nullptr, {1}, DType::kFloat32),
scale_inv(nullptr, {1}, DType::kFloat32) {}
}; };
template <typename T> template <typename T>
...@@ -239,62 +253,8 @@ struct TypeInfo{ ...@@ -239,62 +253,8 @@ struct TypeInfo{
NVTE_ERROR("Invalid type for 16 bit."); \ NVTE_ERROR("Invalid type for 16 bit."); \
} }
// Legacy dispatch helpers: assign each supported data type a small integer
// ID. These IDs are shifted into 2-bit fields by Type2Key (below) to build
// a packed dispatch key.
template<typename T>
struct TypeId{};

template<>
struct TypeId<fp16>{
constexpr static uint32_t Value = 0;
};

template<>
struct TypeId<bf16>{
constexpr static uint32_t Value = 1;
};

template<>
struct TypeId<fp32>{
constexpr static uint32_t Value = 2;
};

template<>
struct TypeId<fp8e4m3>{
constexpr static uint32_t Value = 3;
};
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
// Places the 2-bit TypeId of T at bit offset S within the packed type key.
template<typename T, int S>
struct Type2Key{
constexpr static uint32_t Value = TypeId<T>::Value << S;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Per-operand key fields: each operand class occupies its own 2-bit slot in
// the packed key (weight -> bits 0-1, input -> bits 2-3, output -> bits 4-5,
// compute -> bits 6-7).
template<typename T>
struct WeightType2Key : public Type2Key<T, 0>{};

template<typename T>
struct InputType2Key : public Type2Key<T, 2>{};

template<typename T>
struct OutputType2Key : public Type2Key<T, 4>{};

template<typename T>
struct ComputeType2Key : public Type2Key<T, 6>{};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Combines the weight/input/output/compute type fields into one 32-bit type
// key. get() widens it to 64 bits: the type key lives in the upper 32 bits
// and hidden_size in the lower 32 bits.
template<typename W, typename I, typename O, typename C>
struct Types2Key{
constexpr static uint32_t Value = WeightType2Key<W>::Value | InputType2Key<I>::Value |
OutputType2Key<O>::Value | ComputeType2Key<C>::Value;
constexpr static inline uint64_t get(const uint64_t hidden_size){
constexpr uint64_t type_key = Value;
return (type_key << 32) | hidden_size;
}
};
inline size_t product(const std::vector<size_t> &shape) { inline size_t product(const std::vector<size_t> &shape) {
size_t ret = 1; size_t ret = 1;
for (const auto &elem : shape) { for (const auto &elem : shape) {
...@@ -320,6 +280,11 @@ struct is_fp8<fp8e5m2> : std::true_type {}; ...@@ -320,6 +280,11 @@ struct is_fp8<fp8e5m2> : std::true_type {};
size_t typeToSize(const DType type); size_t typeToSize(const DType type);
void CheckInputTensor(const Tensor &t, const std::string &name);
void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false);
bool is_fp8_dtype(const DType t);
} // namespace transformer_engine } // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_COMMON_H_ #endif // TRANSFORMER_ENGINE_COMMON_COMMON_H_
...@@ -102,7 +102,7 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) { ...@@ -102,7 +102,7 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) {
/* /*
* Extended softmax (from native aten pytorch) with following additional features * Extended softmax (from native aten pytorch) with following additional features
* 1) input scaling * 1) input scaling
*/ */
template <typename input_t, typename output_t, typename acc_t, int log2_elements> template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_softmax_warp_forward( __global__ void scaled_softmax_warp_forward(
output_t *dst, output_t *dst,
...@@ -215,7 +215,7 @@ __global__ void scaled_softmax_warp_forward( ...@@ -215,7 +215,7 @@ __global__ void scaled_softmax_warp_forward(
* Extended softmax (from native aten pytorch) with following additional features * Extended softmax (from native aten pytorch) with following additional features
* 1) input scaling * 1) input scaling
* 2) Explicit masking * 2) Explicit masking
*/ */
template <typename input_t, typename output_t, typename acc_t, int log2_elements> template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_forward( __global__ void scaled_masked_softmax_warp_forward(
output_t *dst, output_t *dst,
...@@ -950,15 +950,15 @@ void scaled_softmax_forward( ...@@ -950,15 +950,15 @@ void scaled_softmax_forward(
float scale_factor, float scale_factor,
cudaStream_t stream) { cudaStream_t stream) {
const int batches = input.shape[0]; const int batches = input.data.shape[0];
const int attn_heads = input.shape[1]; const int attn_heads = input.data.shape[1];
const int query_seq_len = input.shape[2]; const int query_seq_len = input.data.shape[2];
const int key_seq_len = input.shape[3]; const int key_seq_len = input.data.shape[3];
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(input.dtype, softmax_type, TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(input.data.dtype, softmax_type,
dispatch_scaled_softmax_forward<softmax_type, softmax_type, float>( dispatch_scaled_softmax_forward<softmax_type, softmax_type, float>(
reinterpret_cast<softmax_type*>(softmax_results->dptr), reinterpret_cast<softmax_type*>(softmax_results->data.dptr),
reinterpret_cast<const softmax_type*>(input.dptr), reinterpret_cast<const softmax_type*>(input.data.dptr),
scale_factor, scale_factor,
query_seq_len, query_seq_len,
key_seq_len, key_seq_len,
...@@ -975,17 +975,17 @@ void scaled_softmax_backward( ...@@ -975,17 +975,17 @@ void scaled_softmax_backward(
cudaStream_t stream) { cudaStream_t stream) {
// output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] // output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = output_grads.shape[0]; const int batches = output_grads.data.shape[0];
const int attn_heads = output_grads.shape[1]; const int attn_heads = output_grads.data.shape[1];
const int query_seq_len = output_grads.shape[2]; const int query_seq_len = output_grads.data.shape[2];
const int key_seq_len = output_grads.shape[3]; const int key_seq_len = output_grads.data.shape[3];
// Softmax Grad // Softmax Grad
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(output_grads.dtype, softmax_type, TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(output_grads.data.dtype, softmax_type,
dispatch_scaled_masked_softmax_backward<softmax_type, softmax_type, float>( dispatch_scaled_masked_softmax_backward<softmax_type, softmax_type, float>(
reinterpret_cast<softmax_type*>(output_grads.dptr), reinterpret_cast<softmax_type*>(output_grads.data.dptr),
reinterpret_cast<softmax_type const*>(incoming_grads.dptr), reinterpret_cast<softmax_type const*>(incoming_grads.data.dptr),
reinterpret_cast<softmax_type const*>(softmax_results.dptr), reinterpret_cast<softmax_type const*>(softmax_results.data.dptr),
scale_factor, scale_factor,
query_seq_len, query_seq_len,
key_seq_len, key_seq_len,
...@@ -1002,17 +1002,17 @@ void scaled_masked_softmax_forward( ...@@ -1002,17 +1002,17 @@ void scaled_masked_softmax_forward(
float scale_factor, float scale_factor,
cudaStream_t stream) { cudaStream_t stream) {
const int batches = input.shape[0]; const int batches = input.data.shape[0];
const int pad_batches = mask.shape[0]; const int pad_batches = mask.data.shape[0];
const int attn_heads = input.shape[1]; const int attn_heads = input.data.shape[1];
const int query_seq_len = input.shape[2]; const int query_seq_len = input.data.shape[2];
const int key_seq_len = input.shape[3]; const int key_seq_len = input.data.shape[3];
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(input.dtype, softmax_type, TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(input.data.dtype, softmax_type,
dispatch_scaled_masked_softmax_forward<softmax_type, softmax_type, float>( dispatch_scaled_masked_softmax_forward<softmax_type, softmax_type, float>(
reinterpret_cast<softmax_type*>(softmax_results->dptr), reinterpret_cast<softmax_type*>(softmax_results->data.dptr),
reinterpret_cast<const softmax_type*>(input.dptr), reinterpret_cast<const softmax_type*>(input.data.dptr),
reinterpret_cast<const uint8_t*>(mask.dptr), reinterpret_cast<const uint8_t*>(mask.data.dptr),
scale_factor, scale_factor,
query_seq_len, query_seq_len,
key_seq_len, key_seq_len,
...@@ -1031,17 +1031,17 @@ void scaled_masked_softmax_backward( ...@@ -1031,17 +1031,17 @@ void scaled_masked_softmax_backward(
cudaStream_t stream cudaStream_t stream
) { ) {
// output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] // output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = output_grads.shape[0]; const int batches = output_grads.data.shape[0];
const int attn_heads = output_grads.shape[1]; const int attn_heads = output_grads.data.shape[1];
const int query_seq_len = output_grads.shape[2]; const int query_seq_len = output_grads.data.shape[2];
const int key_seq_len = output_grads.shape[3]; const int key_seq_len = output_grads.data.shape[3];
// Softmax Grad // Softmax Grad
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(output_grads.dtype, softmax_type, TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(output_grads.data.dtype, softmax_type,
dispatch_scaled_masked_softmax_backward<softmax_type, softmax_type, float>( dispatch_scaled_masked_softmax_backward<softmax_type, softmax_type, float>(
reinterpret_cast<softmax_type*>(output_grads.dptr), reinterpret_cast<softmax_type*>(output_grads.data.dptr),
reinterpret_cast<softmax_type const*>(incoming_grads.dptr), reinterpret_cast<softmax_type const*>(incoming_grads.data.dptr),
reinterpret_cast<softmax_type const*>(softmax_results.dptr), reinterpret_cast<softmax_type const*>(softmax_results.data.dptr),
scale_factor, scale_factor,
query_seq_len, query_seq_len,
key_seq_len, key_seq_len,
......
...@@ -667,13 +667,13 @@ void scaled_upper_triang_masked_softmax_forward( ...@@ -667,13 +667,13 @@ void scaled_upper_triang_masked_softmax_forward(
float scale_factor, float scale_factor,
cudaStream_t stream) { cudaStream_t stream) {
const int attn_batches = input.shape[0]; const int attn_batches = input.data.shape[0];
const int seq_len = input.shape[1]; const int seq_len = input.data.shape[1];
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(input.dtype, softmax_type, TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(input.data.dtype, softmax_type,
dispatch_scaled_upper_triang_masked_softmax_forward<softmax_type, softmax_type, float>( dispatch_scaled_upper_triang_masked_softmax_forward<softmax_type, softmax_type, float>(
reinterpret_cast<softmax_type*>(softmax_results->dptr), reinterpret_cast<softmax_type*>(softmax_results->data.dptr),
reinterpret_cast<const softmax_type*>(input.dptr), reinterpret_cast<const softmax_type*>(input.data.dptr),
scale_factor, scale_factor,
seq_len, seq_len,
seq_len, seq_len,
...@@ -689,15 +689,15 @@ void scaled_upper_triang_masked_softmax_backward( ...@@ -689,15 +689,15 @@ void scaled_upper_triang_masked_softmax_backward(
float scale_factor, float scale_factor,
cudaStream_t stream) { cudaStream_t stream) {
const int attn_batches = output_grads.shape[0]; const int attn_batches = output_grads.data.shape[0];
const int seq_len = output_grads.shape[1]; const int seq_len = output_grads.data.shape[1];
// Softmax Grad // Softmax Grad
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(output_grads.dtype, softmax_type, TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(output_grads.data.dtype, softmax_type,
dispatch_scaled_upper_triang_masked_softmax_backward<softmax_type, softmax_type, float>( dispatch_scaled_upper_triang_masked_softmax_backward<softmax_type, softmax_type, float>(
reinterpret_cast<softmax_type*>(output_grads.dptr), reinterpret_cast<softmax_type*>(output_grads.data.dptr),
reinterpret_cast<softmax_type const*>(incoming_grads.dptr), reinterpret_cast<softmax_type const*>(incoming_grads.data.dptr),
reinterpret_cast<softmax_type const*>(softmax_results.dptr), reinterpret_cast<softmax_type const*>(softmax_results.data.dptr),
scale_factor, scale_factor,
seq_len, seq_len,
seq_len, seq_len,
......
...@@ -20,17 +20,11 @@ extern "C" { ...@@ -20,17 +20,11 @@ extern "C" {
/*! \brief Compute GELU activation of the input. /*! \brief Compute GELU activation of the input.
* *
* \param[in] input Input tensor for GELU activation. * \param[in] input Input tensor for GELU activation.
* \param[out] output Output tensor. * \param[in,out] output Output tensor.
* \param[in] scale Scaling factor of the output tensor.
* \param[in,out] amax AMAX value of the output tensor.
* \param[out] scale_inv Inverse of the output's scaling factor.
* \param[in] stream CUDA stream used for the operation. * \param[in] stream CUDA stream used for the operation.
*/ */
void nvte_gelu(const NVTETensor input, void nvte_gelu(const NVTETensor input,
NVTETensor output, NVTETensor output,
const NVTETensor scale,
NVTETensor amax,
NVTETensor scale_inv,
cudaStream_t stream); cudaStream_t stream);
#ifdef __cplusplus #ifdef __cplusplus
......
...@@ -20,28 +20,20 @@ extern "C" { ...@@ -20,28 +20,20 @@ extern "C" {
/*! \brief Cast tensor to FP8. /*! \brief Cast tensor to FP8.
* *
* \param[in] input Input tensor to be cast. * \param[in] input Input tensor to be cast.
* \param[in] scale Scaling factor of the output tensor. * \param[in,out] output Output FP8 tensor.
* \param[out] output Output FP8 tensor.
* \param[in,out] amax AMAX value of the output tensor.
* \param[out] scale_inv Inverse of the output's scaling factor.
* \param[in] stream CUDA stream used for the operation. * \param[in] stream CUDA stream used for the operation.
*/ */
void nvte_fp8_quantize(const NVTETensor input, void nvte_fp8_quantize(const NVTETensor input,
const NVTETensor scale,
NVTETensor output, NVTETensor output,
NVTETensor amax,
NVTETensor scale_inv,
cudaStream_t stream); cudaStream_t stream);
/*! \brief Cast tensor from FP8. /*! \brief Cast tensor from FP8.
* *
* \param[in] input Input tensor to be cast. * \param[in] input Input tensor to be cast.
* \param[in] scale_inv Inverse of the input's scaling factor.
* \param[out] output Output tensor. * \param[out] output Output tensor.
* \param[in] stream CUDA stream used for the operation. * \param[in] stream CUDA stream used for the operation.
*/ */
void nvte_fp8_dequantize(const NVTETensor input, void nvte_fp8_dequantize(const NVTETensor input,
const NVTETensor scale_inv,
NVTETensor output, NVTETensor output,
cudaStream_t stream); cudaStream_t stream);
......
...@@ -25,12 +25,10 @@ extern "C" { ...@@ -25,12 +25,10 @@ extern "C" {
* - `D = GELU(AB + bias)` if both `bias` and `pre_gelu_out` are not empty tensors * - `D = GELU(AB + bias)` if both `bias` and `pre_gelu_out` are not empty tensors
* *
* \param[in] A The A matrix. * \param[in] A The A matrix.
* \param[in] A_scale_inverse The inverse of A matrix' scaling factor.
* \param[in] B The B matrix. * \param[in] B The B matrix.
* \param[in] B_scale_inverse The inverse of B matrix' scaling factor. * \param[in,out] D Output matrix.
* \param[out] D Output matrix.
* \param[in] bias Bias tensor. * \param[in] bias Bias tensor.
* \param[out] pre_gelu_out Output matrix before GELU activation. * \param[in,out] pre_gelu_out Output matrix before GELU activation.
* \param[in] transa Whether A matrix is transposed. * \param[in] transa Whether A matrix is transposed.
* \param[in] transb Whether B matrix is transposed. * \param[in] transb Whether B matrix is transposed.
* \param[in] grad Whether this operation is part of the * \param[in] grad Whether this operation is part of the
...@@ -41,9 +39,7 @@ extern "C" { ...@@ -41,9 +39,7 @@ extern "C" {
* \param[in] stream CUDA stream used for the operation. * \param[in] stream CUDA stream used for the operation.
*/ */
void nvte_cublas_gemm(const NVTETensor A, void nvte_cublas_gemm(const NVTETensor A,
const NVTETensor A_scale_inverse,
const NVTETensor B, const NVTETensor B,
const NVTETensor B_scale_inverse,
NVTETensor D, NVTETensor D,
const NVTETensor bias, const NVTETensor bias,
NVTETensor pre_gelu_out, NVTETensor pre_gelu_out,
......
...@@ -26,9 +26,8 @@ extern "C" { ...@@ -26,9 +26,8 @@ extern "C" {
* \param[in] x Input tensor of shape [N, H]. * \param[in] x Input tensor of shape [N, H].
* \param[in] gamma Gamma tensor of shape [H]. * \param[in] gamma Gamma tensor of shape [H].
* \param[in] beta Beta tensor of shape [H]. * \param[in] beta Beta tensor of shape [H].
* \param[in] scale Scaling factor used for output.
* \param[in] epsilon Value added to denominator for numerical stability. * \param[in] epsilon Value added to denominator for numerical stability.
* \param[out] z Output tensor of shape [N, H]. * \param[in,out] z Output tensor of shape [N, H].
* \param[out] mu Mean of the input calculated over the last dimension. * \param[out] mu Mean of the input calculated over the last dimension.
* Shape: [N]. * Shape: [N].
* \param[out] rsigma Inverse of the variance of the input calculated over * \param[out] rsigma Inverse of the variance of the input calculated over
...@@ -37,13 +36,10 @@ extern "C" { ...@@ -37,13 +36,10 @@ extern "C" {
* \param[in] multiprocessorCount Number of SMs in the device. * \param[in] multiprocessorCount Number of SMs in the device.
* \param[out] workspace Workspace tensor. * \param[out] workspace Workspace tensor.
* \param[out] barrier Barrier tensor. * \param[out] barrier Barrier tensor.
* \param[in,out] amax AMAX value of the output tensor.
* \param[out] scale_inv Inverse of the output's scaling factor.
*/ */
void nvte_layernorm_fwd(const NVTETensor x, void nvte_layernorm_fwd(const NVTETensor x,
const NVTETensor gamma, const NVTETensor gamma,
const NVTETensor beta, const NVTETensor beta,
const NVTETensor scale,
const float epsilon, const float epsilon,
NVTETensor z, NVTETensor z,
NVTETensor mu, NVTETensor mu,
...@@ -51,9 +47,7 @@ void nvte_layernorm_fwd(const NVTETensor x, ...@@ -51,9 +47,7 @@ void nvte_layernorm_fwd(const NVTETensor x,
cudaStream_t stream, cudaStream_t stream,
const int multiprocessorCount, const int multiprocessorCount,
NVTETensor workspace, NVTETensor workspace,
NVTETensor barrier, NVTETensor barrier);
NVTETensor amax,
NVTETensor scale_inv);
/*! \brief Compute backward of LayerNorm. /*! \brief Compute backward of LayerNorm.
......
...@@ -56,15 +56,21 @@ typedef void* NVTETensor; ...@@ -56,15 +56,21 @@ typedef void* NVTETensor;
* TE tensors are just wrappers on top of raw data and do not * TE tensors are just wrappers on top of raw data and do not
* own memory. * own memory.
* *
* \param[in] dptr Pointer to the tensor data. * \param[in] dptr Pointer to the tensor data.
* \param[in] shape Shape of the tensor. * \param[in] shape Shape of the tensor.
* \param[in] dtype Data type of the tensor. * \param[in] dtype Data type of the tensor.
* \param[in] amax_dptr Pointer to the AMAX value.
* \param[in] scale_dptr Pointer to the scale value.
* \param[in] scale_inv_dptr Pointer to the inverse of scale value.
* *
* \return A new TE tensor. * \return A new TE tensor.
*/ */
NVTETensor nvte_create_tensor(void *dptr, NVTETensor nvte_create_tensor(void *dptr,
const NVTEShape shape, const NVTEShape shape,
const NVTEDType dtype); const NVTEDType dtype,
float *amax_dptr,
float *scale_dptr,
float *scale_inv_dptr);
/*! \brief Destroy a TE tensor. /*! \brief Destroy a TE tensor.
* *
...@@ -99,6 +105,30 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor); ...@@ -99,6 +105,30 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor);
*/ */
void *nvte_tensor_data(const NVTETensor tensor); void *nvte_tensor_data(const NVTETensor tensor);
/*! \brief Get a pointer to the tensor's amax data.
*
* \param[in] tensor Tensor.
*
* \return A pointer to tensor's amax data.
*/
float *nvte_tensor_amax(const NVTETensor tensor);
/*! \brief Get a pointer to the tensor's scale data.
*
* \param[in] tensor Tensor.
*
* \return A pointer to tensor's scale data.
*/
float *nvte_tensor_scale(const NVTETensor tensor);
/*! \brief Get a pointer to the tensor's inverse of scale data.
*
* \param[in] tensor Tensor.
*
* \return A pointer to tensor's inverse of scale data.
*/
float *nvte_tensor_scale_inv(const NVTETensor tensor);
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
...@@ -138,9 +168,15 @@ class TensorWrapper { ...@@ -138,9 +168,15 @@ class TensorWrapper {
* \param[in] dptr Pointer to the tensor data. * \param[in] dptr Pointer to the tensor data.
* \param[in] shape Shape of the tensor. * \param[in] shape Shape of the tensor.
* \param[in] dtype Data type of the tensor. * \param[in] dtype Data type of the tensor.
* \param[in] amax_dptr Pointer to the AMAX value.
* \param[in] scale_dptr Pointer to the scale value.
* \param[in] scale_inv_dptr Pointer to the inverse of scale value.
*/ */
TensorWrapper(void *dptr, const NVTEShape &shape, const DType dtype) : TensorWrapper(void *dptr, const NVTEShape &shape, const DType dtype,
tensor_(nvte_create_tensor(dptr, shape, static_cast<NVTEDType>(dtype))) {} float *amax_dptr = nullptr, float *scale_dptr = nullptr,
float *scale_inv_dptr = nullptr) :
tensor_(nvte_create_tensor(dptr, shape, static_cast<NVTEDType>(dtype),
amax_dptr, scale_dptr, scale_inv_dptr)) {}
/*! \brief Constructs new TensorWrapper. /*! \brief Constructs new TensorWrapper.
* *
...@@ -151,9 +187,15 @@ class TensorWrapper { ...@@ -151,9 +187,15 @@ class TensorWrapper {
* \param[in] dptr Pointer to the tensor data. * \param[in] dptr Pointer to the tensor data.
* \param[in] shape Shape of the tensor. * \param[in] shape Shape of the tensor.
* \param[in] dtype Data type of the tensor. * \param[in] dtype Data type of the tensor.
* \param[in] amax_dptr Pointer to the AMAX value.
* \param[in] scale_dptr Pointer to the scale value.
* \param[in] scale_inv_dptr Pointer to the inverse of scale value.
*/ */
TensorWrapper(void *dptr, const std::vector<size_t> &shape, const DType dtype) : TensorWrapper(void *dptr, const std::vector<size_t> &shape, const DType dtype,
TensorWrapper(dptr, NVTEShape{shape.data(), shape.size()}, dtype) {} float *amax_dptr = nullptr, float *scale_dptr = nullptr,
float *scale_inv_dptr = nullptr) :
TensorWrapper(dptr, NVTEShape{shape.data(), shape.size()}, dtype,
amax_dptr, scale_dptr, scale_inv_dptr) {}
/*! \brief Constructs new empty TensorWrapper. /*! \brief Constructs new empty TensorWrapper.
* *
...@@ -229,6 +271,33 @@ class TensorWrapper { ...@@ -229,6 +271,33 @@ class TensorWrapper {
return nvte_tensor_data(tensor_); return nvte_tensor_data(tensor_);
} }
/*! \brief Get a pointer to the tensor's amax data.
 *
 * \return A pointer to tensor's amax data, or nullptr when this wrapper
 *         holds no tensor.
 */
float *amax() const noexcept {
  return (tensor_ != nullptr) ? nvte_tensor_amax(tensor_) : nullptr;
}
/*! \brief Get a pointer to the tensor's scale data.
 *
 * \return A pointer to tensor's scale data, or nullptr when this wrapper
 *         holds no tensor.
 */
float *scale() const noexcept {
  return (tensor_ != nullptr) ? nvte_tensor_scale(tensor_) : nullptr;
}
/*! \brief Get a pointer to the tensor's inverse of scale data.
 *
 * \return A pointer to tensor's inverse of scale data, or nullptr when this
 *         wrapper holds no tensor.
 */
float *scale_inv() const noexcept {
  return (tensor_ != nullptr) ? nvte_tensor_scale_inv(tensor_) : nullptr;
}
private: private:
/*! \brief Wrapped NVTETensor. */ /*! \brief Wrapped NVTETensor. */
NVTETensor tensor_ = nullptr; NVTETensor tensor_ = nullptr;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment