Unverified Commit 7f270330 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Remove intermediate dispatch functions (#56)


Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent f22929cc
......@@ -172,7 +172,6 @@ TEST_P(MultiCastTransposeTestSuite, TestMultiCastTranspose) {
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
MultiCastTransposeTestSuite,
......
......@@ -124,317 +124,3 @@ at::Tensor allocateTorchTensor(int M,
return at::empty({static_cast<int64_t>(M)},
at::CUDA(GetATenDType(dtype)));
}
void dispatch_layernorm(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* gamma, // i
const std::vector<size_t>& gamma_shape,
const transformer_engine::DType gamma_type,
void* beta, // i
const std::vector<size_t>& beta_shape,
const transformer_engine::DType beta_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
const float epsilon, // i
void* z, // o
const std::vector<size_t>& z_shape,
const transformer_engine::DType z_type,
void* mu, // o
const std::vector<size_t>& mu_shape,
const transformer_engine::DType mu_type,
void* rsigma, // o
const std::vector<size_t>& rsigma_shape,
const transformer_engine::DType rsigma_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type,
const int multiProcessorCount
) {
auto input_cu = makeTransformerEngineTensor(input, input_shape, input_type);
auto gamma_cu = makeTransformerEngineTensor(gamma, gamma_shape, gamma_type);
auto beta_cu = makeTransformerEngineTensor(beta, beta_shape, beta_type);
auto z_cu = makeTransformerEngineTensor(z, z_shape, z_type, amax, scale, scale_inv);
auto mu_cu = makeTransformerEngineTensor(mu, mu_shape, mu_type);
auto rsigma_cu = makeTransformerEngineTensor(rsigma, rsigma_shape, rsigma_type);
transformer_engine::TensorWrapper workspace, barrier;
// This call populates workspace and barrier tensors with the required config
nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(),
epsilon, z_cu.data(), mu_cu.data(), rsigma_cu.data(),
at::cuda::getCurrentCUDAStream(), multiProcessorCount,
workspace.data(), barrier.data());
// Fill workspace and barrier
auto workspace_data = allocateSpace(workspace.shape(),
workspace.dtype());
auto barrier_data = allocateSpace(barrier.shape(),
barrier.dtype(),
true);
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
barrier = makeTransformerEngineTensor(barrier_data.data_ptr(),
barrier.shape(),
barrier.dtype());
// Actual call to fwd kernel
nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(),
epsilon, z_cu.data(), mu_cu.data(), rsigma_cu.data(),
at::cuda::getCurrentCUDAStream(), multiProcessorCount,
workspace.data(), barrier.data());
}
void dispatch_cast_transpose_fusion(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
void* output_cast, // o
const std::vector<size_t>& output_cast_shape,
const transformer_engine::DType output_cast_type,
void* output_transpose, // o
const std::vector<size_t>& output_transpose_shape,
const transformer_engine::DType output_transpose_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type
) {
auto input_cu = makeTransformerEngineTensor(input, input_shape, input_type);
auto output_cast_cu = makeTransformerEngineTensor(output_cast, output_cast_shape,
output_cast_type, amax, scale,
scale_inv);
auto output_transpose_cu = makeTransformerEngineTensor(output_transpose, output_transpose_shape,
output_transpose_type, amax,
scale, scale_inv);
nvte_cast_transpose(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(),
at::cuda::getCurrentCUDAStream());
}
void dispatch_gelu(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
void* output, // o
const std::vector<size_t>& output_shape,
const transformer_engine::DType output_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type
) {
auto input_cu = makeTransformerEngineTensor(input, input_shape, input_type);
auto output_cu = makeTransformerEngineTensor(output, output_shape, output_type,
amax, scale, scale_inv);
nvte_gelu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
}
void dispatch_transpose(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* output, // o
const std::vector<size_t>& output_shape,
const transformer_engine::DType output_type
) {
auto input_cu = makeTransformerEngineTensor(input, input_shape, input_type);
auto output_cu = makeTransformerEngineTensor(output, output_shape, output_type);
nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
}
void dispatch_bgrad_cast_transpose_fusion(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
void* cast_output, // o
const std::vector<size_t>& cast_output_shape,
const transformer_engine::DType cast_output_type,
void* transposed_output, // o
const std::vector<size_t>& transposed_output_shape,
const transformer_engine::DType transposed_output_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* dbias, // o
const std::vector<size_t>& dbias_shape,
const transformer_engine::DType dbias_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type
) {
auto input_cu = makeTransformerEngineTensor(input, input_shape, input_type);
auto cast_output_cu = makeTransformerEngineTensor(cast_output, cast_output_shape,
cast_output_type, amax, scale,
scale_inv);
auto transposed_output_cu = makeTransformerEngineTensor(transposed_output,
transposed_output_shape,
transposed_output_type,
amax, scale, scale_inv);
auto dbias_cu = makeTransformerEngineTensor(dbias, dbias_shape, dbias_type);
transformer_engine::TensorWrapper workspace;
nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(),
transposed_output_cu.data(), dbias_cu.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
// Fill workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(),
transposed_output_cu.data(), dbias_cu.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
}
void dispatch_bgrad_dgelu_cast_transpose_fusion(
void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* gelu_input, // i
const std::vector<size_t>& gelu_input_shape,
const transformer_engine::DType gelu_input_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
void* cast_output, // o
const std::vector<size_t>& cast_output_shape,
const transformer_engine::DType cast_output_type,
void* transposed_output, // o
const std::vector<size_t>& transposed_output_shape,
const transformer_engine::DType transposed_output_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* dbias, // o
const std::vector<size_t>& dbias_shape,
const transformer_engine::DType dbias_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type
) {
transformer_engine::TensorWrapper workspace;
auto gelu_input_cu = makeTransformerEngineTensor(gelu_input, gelu_input_shape,
gelu_input_type);
auto input_cu = makeTransformerEngineTensor(input, input_shape, input_type);
auto cast_output_cu = makeTransformerEngineTensor(cast_output, cast_output_shape,
cast_output_type, amax, scale,
scale_inv);
auto transposed_output_cu = makeTransformerEngineTensor(transposed_output,
transposed_output_shape,
transposed_output_type,
amax, scale, scale_inv);
auto dbias_cu = makeTransformerEngineTensor(dbias, dbias_shape, dbias_type);
nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(),
cast_output_cu.data(), transposed_output_cu.data(),
dbias_cu.data(), workspace.data(),
at::cuda::getCurrentCUDAStream());
// Fill workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(),
cast_output_cu.data(), transposed_output_cu.data(),
dbias_cu.data(), workspace.data(),
at::cuda::getCurrentCUDAStream());
}
void dispatch_multi_cast_transpose(
std::vector<void*> input_dptr_list, // i
const std::vector<std::vector<size_t>>& input_shape_list,
const std::vector<transformer_engine::DType>& input_type_list,
std::vector<void*> scale_dptr_list, // i
const std::vector<std::vector<size_t>>& scale_shape_list,
const std::vector<transformer_engine::DType>& scale_type_list,
std::vector<void*> cast_output_dptr_list, // o
const std::vector<std::vector<size_t>>& cast_output_shape_list,
const std::vector<transformer_engine::DType>& cast_output_type_list,
std::vector<void*> transposed_output_dptr_list, // o
const std::vector<std::vector<size_t>>& transposed_output_shape_list,
const std::vector<transformer_engine::DType>& transposed_output_type_list,
std::vector<void*> amax_dptr_list, // o
const std::vector<std::vector<size_t>>& amax_shape_list,
const std::vector<transformer_engine::DType>& amax_type_list,
std::vector<void*> scale_inv_dptr_list, // o
const std::vector<std::vector<size_t>>& scale_inv_shape_list,
const std::vector<transformer_engine::DType>& scale_inv_type_list
) {
transformer_engine::TensorWrapper workspace;
// Construct TE tensors
std::vector<NVTETensor> input_list,
cast_output_list, transposed_output_list;
std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
auto make_tensor = [&tensor_wrappers](void* dptr,
const std::vector<size_t>& shape,
transformer_engine::DType dtype,
void* amax_dptr,
void* scale_dptr,
void* scale_inv_dptr)
-> NVTETensor {
tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr,
scale_dptr, scale_inv_dptr));
return tensor_wrappers.back().data();
};
for (size_t i = 0; i < input_dptr_list.size(); ++i) {
input_list.emplace_back(make_tensor(input_dptr_list[i],
input_shape_list[i],
input_type_list[i],
nullptr,
nullptr,
nullptr));
cast_output_list.emplace_back(make_tensor(cast_output_dptr_list[i],
cast_output_shape_list[i],
cast_output_type_list[i],
amax_dptr_list[i],
scale_dptr_list[i],
scale_inv_dptr_list[i]));
transposed_output_list.emplace_back(make_tensor(transposed_output_dptr_list[i],
transposed_output_shape_list[i],
transposed_output_type_list[i],
amax_dptr_list[i],
scale_dptr_list[i],
scale_inv_dptr_list[i]));
}
// Check tensor lists
NVTE_CHECK(cast_output_list.size() == input_list.size(),
"Number of input and C output tensors must match");
NVTE_CHECK(transposed_output_list.size() == input_list.size(),
"Number of input and T output tensors must match");
// Launch TE kernel
nvte_multi_cast_transpose(input_list.size(),
input_list.data(),
cast_output_list.data(),
transposed_output_list.data(),
at::cuda::getCurrentCUDAStream());
}
......@@ -153,158 +153,4 @@ at::Tensor allocateTorchTensor(int M,
);
void dispatch_layernorm(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* gamma, // i
const std::vector<size_t>& gamma_shape,
const transformer_engine::DType gamma_type,
void* beta, // i
const std::vector<size_t>& beta_shape,
const transformer_engine::DType beta_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
const float epsilon, // i
void* z, // o
const std::vector<size_t>& z_shape,
const transformer_engine::DType z_type,
void* mu, // o
const std::vector<size_t>& mu_shape,
const transformer_engine::DType mu_type,
void* rsigma, // o
const std::vector<size_t>& rsigma_shape,
const transformer_engine::DType rsigma_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type,
const int multiProcessorCount
);
void dispatch_cast_transpose_fusion(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
void* output_cast, // o
const std::vector<size_t>& output_cast_shape,
const transformer_engine::DType output_cast_type,
void* output_transpose, // o
const std::vector<size_t>& output_transpose_shape,
const transformer_engine::DType output_transpose_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type
);
void dispatch_gelu(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
void* output, // o
const std::vector<size_t>& output_shape,
const transformer_engine::DType output_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type
);
void dispatch_transpose(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* output, // o
const std::vector<size_t>& output_shape,
const transformer_engine::DType output_type
);
void dispatch_bgrad_cast_transpose_fusion(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
void* cast_output, // o
const std::vector<size_t>& cast_output_shape,
const transformer_engine::DType cast_output_type,
void* transposed_output, // o
const std::vector<size_t>& transposed_output_shape,
const transformer_engine::DType transposed_output_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* dbias, // o
const std::vector<size_t>& dbias_shape,
const transformer_engine::DType dbias_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type
);
void dispatch_bgrad_dgelu_cast_transpose_fusion(
void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* gelu_input, // i
const std::vector<size_t>& gelu_input_shape,
const transformer_engine::DType gelu_input_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
void* cast_output, // o
const std::vector<size_t>& cast_output_shape,
const transformer_engine::DType cast_output_type,
void* transposed_output, // o
const std::vector<size_t>& transposed_output_shape,
const transformer_engine::DType transposed_output_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* dbias, // o
const std::vector<size_t>& dbias_shape,
const transformer_engine::DType dbias_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type
);
void dispatch_multi_cast_transpose(
std::vector<void*> input_dptr_list, // i
const std::vector<std::vector<size_t>>& input_shape_list,
const std::vector<transformer_engine::DType>& input_type_list,
std::vector<void*> scale_dptr_list, // i
const std::vector<std::vector<size_t>>& scale_shape_list,
const std::vector<transformer_engine::DType>& scale_type_list,
std::vector<void*> cast_output_dptr_list, // o
const std::vector<std::vector<size_t>>& cast_output_shape_list,
const std::vector<transformer_engine::DType>& cast_output_type_list,
std::vector<void*> transposed_output_dptr_list, // o
const std::vector<std::vector<size_t>>& transposed_output_shape_list,
const std::vector<transformer_engine::DType>& transposed_output_type_list,
std::vector<void*> amax_dptr_list, // o
const std::vector<std::vector<size_t>>& amax_shape_list,
const std::vector<transformer_engine::DType>& amax_type_list,
std::vector<void*> scale_inv_dptr_list, // o
const std::vector<std::vector<size_t>>& scale_inv_shape_list,
const std::vector<transformer_engine::DType>& scale_inv_type_list
);
#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_
......@@ -83,15 +83,16 @@ void fused_cast_transpose(at::Tensor input,
size_t M = static_cast<size_t>(input.size(0));
size_t N = static_cast<size_t>(input.size(1));
DType inp_type = GetTransformerEngineDType(input.scalar_type());
dispatch_cast_transpose_fusion(
input.data_ptr(), {M, N}, inp_type,
scale.data_ptr(), {1}, DType::kFloat32,
input_cast.data_ptr(), {M, N}, otype,
input_transpose.data_ptr(), {N, M}, otype,
amax.data_ptr(), {1}, DType::kFloat32,
scale_inv.data_ptr(), {1}, DType::kFloat32);
auto input_cu = makeTransformerEngineTensor(input);
auto output_cast_cu = makeTransformerEngineTensor(input_cast.data_ptr(), {M, N}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
auto output_transpose_cu = makeTransformerEngineTensor(input_transpose.data_ptr(), {N, M}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
nvte_cast_transpose(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(),
at::cuda::getCurrentCUDAStream());
}
......@@ -117,14 +118,29 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output,
grad_output.size(0),
DType::kByte);
dispatch_bgrad_cast_transpose_fusion(
grad_output.data_ptr(), {M, N}, grad_output_type,
scale.data_ptr(), {1}, DType::kFloat32,
grad_output_cast.data_ptr(), {M, N}, otype,
grad_output_transpose.data_ptr(), {N, M}, otype,
amax.data_ptr(), {1}, DType::kFloat32,
grad_bias.data_ptr(), {N}, grad_output_type,
scale_inv.data_ptr(), {1}, DType::kFloat32);
auto input_cu = makeTransformerEngineTensor(grad_output);
auto cast_output_cu = makeTransformerEngineTensor(grad_output_cast.data_ptr(), {M, N},
otype, amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
auto transposed_output_cu = makeTransformerEngineTensor(grad_output_transpose.data_ptr(),
{N, M}, otype, amax.data_ptr(),
scale.data_ptr(), scale_inv.data_ptr());
auto dbias_cu = makeTransformerEngineTensor(grad_bias);
transformer_engine::TensorWrapper workspace;
nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(),
transposed_output_cu.data(), dbias_cu.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
// Fill workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(),
transposed_output_cu.data(), dbias_cu.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
return {grad_bias, grad_output_cast, grad_output_transpose};
}
......@@ -153,15 +169,32 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
grad_output.size(0),
DType::kByte);
dispatch_bgrad_dgelu_cast_transpose_fusion(
grad_output.data_ptr(), {M, N}, grad_output_type,
gelu_input.data_ptr(), {M, N}, grad_output_type,
scale.data_ptr(), {1}, DType::kFloat32,
dgelu.data_ptr(), {M, N}, otype,
dgelu_transpose.data_ptr(), {N, M}, otype,
amax.data_ptr(), {1}, DType::kFloat32,
grad_bias.data_ptr(), {N}, grad_output_type,
scale_inv.data_ptr(), {1}, DType::kFloat32);
transformer_engine::TensorWrapper workspace;
auto gelu_input_cu = makeTransformerEngineTensor(gelu_input);
auto input_cu = makeTransformerEngineTensor(grad_output);
auto cast_output_cu = makeTransformerEngineTensor(dgelu.data_ptr(), {M, N},
otype, amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
auto transposed_output_cu = makeTransformerEngineTensor(dgelu_transpose.data_ptr(), {N, M},
otype, amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
auto dbias_cu = makeTransformerEngineTensor(grad_bias);
nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(),
cast_output_cu.data(), transposed_output_cu.data(),
dbias_cu.data(), workspace.data(),
at::cuda::getCurrentCUDAStream());
// Fill workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(),
cast_output_cu.data(), transposed_output_cu.data(),
dbias_cu.data(), workspace.data(),
at::cuda::getCurrentCUDAStream());
return {grad_bias, dgelu, dgelu_transpose};
}
......@@ -234,26 +267,56 @@ void fused_multi_cast_transpose(std::vector<at::Tensor> input_list,
scale_inv_type_list);
}
transformer_engine::TensorWrapper workspace;
// Construct TE tensors
std::vector<NVTETensor> nvte_input_list,
nvte_cast_output_list, nvte_transposed_output_list;
std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
auto make_tensor = [&tensor_wrappers](void* dptr,
const std::vector<size_t>& shape,
transformer_engine::DType dtype,
void* amax_dptr,
void* scale_dptr,
void* scale_inv_dptr)
-> NVTETensor {
tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr,
scale_dptr, scale_inv_dptr));
return tensor_wrappers.back().data();
};
for (size_t i = 0; i < input_dptr_list.size(); ++i) {
nvte_input_list.emplace_back(make_tensor(input_dptr_list[i],
input_shape_list[i],
input_type_list[i],
nullptr,
nullptr,
nullptr));
nvte_cast_output_list.emplace_back(make_tensor(cast_output_dptr_list[i],
cast_output_shape_list[i],
cast_output_type_list[i],
amax_dptr_list[i],
scale_dptr_list[i],
scale_inv_dptr_list[i]));
nvte_transposed_output_list.emplace_back(make_tensor(transposed_output_dptr_list[i],
transposed_output_shape_list[i],
transposed_output_type_list[i],
amax_dptr_list[i],
scale_dptr_list[i],
scale_inv_dptr_list[i]));
}
// Check tensor lists
NVTE_CHECK(nvte_cast_output_list.size() == nvte_input_list.size(),
"Number of input and C output tensors must match");
NVTE_CHECK(nvte_transposed_output_list.size() == nvte_input_list.size(),
"Number of input and T output tensors must match");
// Launch TE kernel
dispatch_multi_cast_transpose(
input_dptr_list,
input_shape_list,
input_type_list,
scale_dptr_list,
scale_shape_list,
scale_type_list,
cast_output_dptr_list,
cast_output_shape_list,
cast_output_type_list,
transposed_output_dptr_list,
transposed_output_shape_list,
transposed_output_type_list,
amax_dptr_list,
amax_shape_list,
amax_type_list,
scale_inv_dptr_list,
scale_inv_shape_list,
scale_inv_type_list);
nvte_multi_cast_transpose(nvte_input_list.size(),
nvte_input_list.data(),
nvte_cast_output_list.data(),
nvte_transposed_output_list.data(),
at::cuda::getCurrentCUDAStream());
}
......@@ -265,14 +328,17 @@ at::Tensor fp8_transpose(at::Tensor input,
size_t M = static_cast<size_t>(input.size(0));
size_t N = static_cast<size_t>(input.size(1));
auto input_transpose =
auto output =
allocateTorchTensor(input.size(1),
input.size(0),
DType::kByte);
dispatch_transpose(input.data_ptr(), {M, N}, otype,
input_transpose.data_ptr(), {N, M}, otype);
return input_transpose;
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype);
nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
......@@ -287,18 +353,17 @@ at::Tensor fp8_gelu(at::Tensor input,
size_t M = static_cast<size_t>(input.size(0));
size_t N = static_cast<size_t>(input.size(1));
DType input_type = GetTransformerEngineDType(input.scalar_type());
auto output =
allocateTorchTensor(input.size(0),
input.size(1),
DType::kByte);
dispatch_gelu(input.data_ptr(), {M, N}, input_type,
scale.data_ptr(), {1}, DType::kFloat32,
output.data_ptr(), {M, N}, otype,
amax.data_ptr(), {1}, DType::kFloat32,
scale_inv.data_ptr(), {1}, DType::kFloat32);
auto input_cu = makeTransformerEngineTensor(input);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
nvte_gelu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
......@@ -379,19 +444,40 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(otype)));
auto mu = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto input_cu = makeTransformerEngineTensor(input);
auto gamma_cu = makeTransformerEngineTensor(weight);
auto beta_cu = makeTransformerEngineTensor(bias);
auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
auto mu_cu = makeTransformerEngineTensor(mu);
auto rsigma_cu = makeTransformerEngineTensor(rsigma);
transformer_engine::TensorWrapper workspace, barrier;
// This call populates workspace and barrier tensors with the required config
nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount,
workspace.data(), barrier.data());
dispatch_layernorm(
input.data_ptr(), {N, H}, itype,
weight.data_ptr(), {H}, itype,
bias.data_ptr(), {H}, itype,
scale.data_ptr(), {1}, DType::kFloat32,
eps,
ln_out.data_ptr(), {N, H}, otype,
mu.data_ptr(), {N}, DType::kFloat32,
rsigma.data_ptr(), {N}, DType::kFloat32,
amax.data_ptr(), {1}, DType::kFloat32,
scale_inv.data_ptr(), {1}, DType::kFloat32,
at::cuda::getCurrentDeviceProperties()->multiProcessorCount);
// Fill workspace and barrier
auto workspace_data = allocateSpace(workspace.shape(),
workspace.dtype());
auto barrier_data = allocateSpace(barrier.shape(),
barrier.dtype(),
true);
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
barrier = makeTransformerEngineTensor(barrier_data.data_ptr(),
barrier.shape(),
barrier.dtype());
// Actual call to fwd kernel
nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount,
workspace.data(), barrier.data());
return {ln_out, mu, rsigma};
}
......@@ -429,22 +515,43 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(itype)));
auto mu = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto input_cu = makeTransformerEngineTensor(input);
auto gamma_cu = makeTransformerEngineTensor(weight);
auto beta_cu = makeTransformerEngineTensor(bias);
auto z_cu = makeTransformerEngineTensor(ln_out);
auto mu_cu = makeTransformerEngineTensor(mu);
auto rsigma_cu = makeTransformerEngineTensor(rsigma);
transformer_engine::TensorWrapper workspace, barrier;
dispatch_layernorm(input.data_ptr(), {N, H}, itype,
weight.data_ptr(), {H}, itype,
bias.data_ptr(), {H}, itype,
nullptr, {1}, DType::kFloat32,
eps,
ln_out.data_ptr(), {N, H}, itype,
mu.data_ptr(), {N}, DType::kFloat32,
rsigma.data_ptr(), {N}, DType::kFloat32,
nullptr, {1}, DType::kFloat32,
nullptr, {1}, DType::kFloat32,
at::cuda::getCurrentDeviceProperties()->multiProcessorCount);
// This call populates workspace and barrier tensors with the required config
nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount,
workspace.data(), barrier.data());
// Fill workspace and barrier
auto workspace_data = allocateSpace(workspace.shape(),
workspace.dtype());
auto barrier_data = allocateSpace(barrier.shape(),
barrier.dtype(),
true);
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
barrier = makeTransformerEngineTensor(barrier_data.data_ptr(),
barrier.shape(),
barrier.dtype());
// Actual call to fwd kernel
nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount,
workspace.data(), barrier.data());
return {ln_out, mu, rsigma};
}
at::Tensor layernorm_fwd_inf(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
......@@ -456,6 +563,7 @@ at::Tensor layernorm_fwd_inf(const at::Tensor &input,
return out[0];
}
at::Tensor cast_to_fp8(const at::Tensor &input,
const at::Tensor &scale,
at::Tensor amax,
......@@ -494,7 +602,6 @@ at::Tensor cast_from_fp8(const at::Tensor &input,
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {N, H}, itype,
nullptr, nullptr, scale_inv.data_ptr());
auto output_cu = makeTransformerEngineTensor(output);
auto scale_inv_cu = makeTransformerEngineTensor(scale_inv.data_ptr(), {1}, DType::kFloat32);
nvte_fp8_dequantize(input_cu.data(), output_cu.data(),
at::cuda::getCurrentCUDAStream());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment