Unverified commit 7f270330 authored by Kirthi Shankar Sivamani, committed by GitHub

Remove intermediate dispatch functions (#56)


Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent f22929cc
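
Every hunk below follows the same pattern: instead of routing through an intermediate dispatch_* helper that re-packs raw pointers, shapes, and dtypes, each binding now builds transformer_engine::TensorWrapper objects with makeTransformerEngineTensor and calls the corresponding nvte_* kernel directly. A condensed before/after sketch of that pattern, drawn from the fp8_gelu hunk further down (an excerpt, not a standalone program; all names come from the diff itself):

  // Before: thin wrapper re-packs pointers/shapes/dtypes and forwards them.
  dispatch_gelu(input.data_ptr(), {M, N}, input_type,
                scale.data_ptr(), {1}, DType::kFloat32,
                output.data_ptr(), {M, N}, otype,
                amax.data_ptr(), {1}, DType::kFloat32,
                scale_inv.data_ptr(), {1}, DType::kFloat32);

  // After: build TE tensors at the call site and launch the kernel directly.
  auto input_cu  = makeTransformerEngineTensor(input);
  auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype,
                                               amax.data_ptr(), scale.data_ptr(),
                                               scale_inv.data_ptr());
  nvte_gelu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());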
@@ -172,7 +172,6 @@ TEST_P(MultiCastTransposeTestSuite, TestMultiCastTranspose) {
 }

 INSTANTIATE_TEST_SUITE_P(
   OperatorTest,
   MultiCastTransposeTestSuite,
...
@@ -124,317 +124,3 @@ at::Tensor allocateTorchTensor(int M,
   return at::empty({static_cast<int64_t>(M)},
                    at::CUDA(GetATenDType(dtype)));
 }
void dispatch_layernorm(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* gamma, // i
const std::vector<size_t>& gamma_shape,
const transformer_engine::DType gamma_type,
void* beta, // i
const std::vector<size_t>& beta_shape,
const transformer_engine::DType beta_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
const float epsilon, // i
void* z, // o
const std::vector<size_t>& z_shape,
const transformer_engine::DType z_type,
void* mu, // o
const std::vector<size_t>& mu_shape,
const transformer_engine::DType mu_type,
void* rsigma, // o
const std::vector<size_t>& rsigma_shape,
const transformer_engine::DType rsigma_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type,
const int multiProcessorCount
) {
auto input_cu = makeTransformerEngineTensor(input, input_shape, input_type);
auto gamma_cu = makeTransformerEngineTensor(gamma, gamma_shape, gamma_type);
auto beta_cu = makeTransformerEngineTensor(beta, beta_shape, beta_type);
auto z_cu = makeTransformerEngineTensor(z, z_shape, z_type, amax, scale, scale_inv);
auto mu_cu = makeTransformerEngineTensor(mu, mu_shape, mu_type);
auto rsigma_cu = makeTransformerEngineTensor(rsigma, rsigma_shape, rsigma_type);
transformer_engine::TensorWrapper workspace, barrier;
// This call populates workspace and barrier tensors with the required config
nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(),
epsilon, z_cu.data(), mu_cu.data(), rsigma_cu.data(),
at::cuda::getCurrentCUDAStream(), multiProcessorCount,
workspace.data(), barrier.data());
// Fill workspace and barrier
auto workspace_data = allocateSpace(workspace.shape(),
workspace.dtype());
auto barrier_data = allocateSpace(barrier.shape(),
barrier.dtype(),
true);
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
barrier = makeTransformerEngineTensor(barrier_data.data_ptr(),
barrier.shape(),
barrier.dtype());
// Actual call to fwd kernel
nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(),
epsilon, z_cu.data(), mu_cu.data(), rsigma_cu.data(),
at::cuda::getCurrentCUDAStream(), multiProcessorCount,
workspace.data(), barrier.data());
}
void dispatch_cast_transpose_fusion(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
void* output_cast, // o
const std::vector<size_t>& output_cast_shape,
const transformer_engine::DType output_cast_type,
void* output_transpose, // o
const std::vector<size_t>& output_transpose_shape,
const transformer_engine::DType output_transpose_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type
) {
auto input_cu = makeTransformerEngineTensor(input, input_shape, input_type);
auto output_cast_cu = makeTransformerEngineTensor(output_cast, output_cast_shape,
output_cast_type, amax, scale,
scale_inv);
auto output_transpose_cu = makeTransformerEngineTensor(output_transpose, output_transpose_shape,
output_transpose_type, amax,
scale, scale_inv);
nvte_cast_transpose(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(),
at::cuda::getCurrentCUDAStream());
}
void dispatch_gelu(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
void* output, // o
const std::vector<size_t>& output_shape,
const transformer_engine::DType output_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type
) {
auto input_cu = makeTransformerEngineTensor(input, input_shape, input_type);
auto output_cu = makeTransformerEngineTensor(output, output_shape, output_type,
amax, scale, scale_inv);
nvte_gelu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
}
void dispatch_transpose(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* output, // o
const std::vector<size_t>& output_shape,
const transformer_engine::DType output_type
) {
auto input_cu = makeTransformerEngineTensor(input, input_shape, input_type);
auto output_cu = makeTransformerEngineTensor(output, output_shape, output_type);
nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
}
void dispatch_bgrad_cast_transpose_fusion(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
void* cast_output, // o
const std::vector<size_t>& cast_output_shape,
const transformer_engine::DType cast_output_type,
void* transposed_output, // o
const std::vector<size_t>& transposed_output_shape,
const transformer_engine::DType transposed_output_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* dbias, // o
const std::vector<size_t>& dbias_shape,
const transformer_engine::DType dbias_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type
) {
auto input_cu = makeTransformerEngineTensor(input, input_shape, input_type);
auto cast_output_cu = makeTransformerEngineTensor(cast_output, cast_output_shape,
cast_output_type, amax, scale,
scale_inv);
auto transposed_output_cu = makeTransformerEngineTensor(transposed_output,
transposed_output_shape,
transposed_output_type,
amax, scale, scale_inv);
auto dbias_cu = makeTransformerEngineTensor(dbias, dbias_shape, dbias_type);
transformer_engine::TensorWrapper workspace;
nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(),
transposed_output_cu.data(), dbias_cu.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
// Fill workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(),
transposed_output_cu.data(), dbias_cu.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
}
void dispatch_bgrad_dgelu_cast_transpose_fusion(
void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* gelu_input, // i
const std::vector<size_t>& gelu_input_shape,
const transformer_engine::DType gelu_input_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
void* cast_output, // o
const std::vector<size_t>& cast_output_shape,
const transformer_engine::DType cast_output_type,
void* transposed_output, // o
const std::vector<size_t>& transposed_output_shape,
const transformer_engine::DType transposed_output_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* dbias, // o
const std::vector<size_t>& dbias_shape,
const transformer_engine::DType dbias_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type
) {
transformer_engine::TensorWrapper workspace;
auto gelu_input_cu = makeTransformerEngineTensor(gelu_input, gelu_input_shape,
gelu_input_type);
auto input_cu = makeTransformerEngineTensor(input, input_shape, input_type);
auto cast_output_cu = makeTransformerEngineTensor(cast_output, cast_output_shape,
cast_output_type, amax, scale,
scale_inv);
auto transposed_output_cu = makeTransformerEngineTensor(transposed_output,
transposed_output_shape,
transposed_output_type,
amax, scale, scale_inv);
auto dbias_cu = makeTransformerEngineTensor(dbias, dbias_shape, dbias_type);
nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(),
cast_output_cu.data(), transposed_output_cu.data(),
dbias_cu.data(), workspace.data(),
at::cuda::getCurrentCUDAStream());
// Fill workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(),
cast_output_cu.data(), transposed_output_cu.data(),
dbias_cu.data(), workspace.data(),
at::cuda::getCurrentCUDAStream());
}
void dispatch_multi_cast_transpose(
std::vector<void*> input_dptr_list, // i
const std::vector<std::vector<size_t>>& input_shape_list,
const std::vector<transformer_engine::DType>& input_type_list,
std::vector<void*> scale_dptr_list, // i
const std::vector<std::vector<size_t>>& scale_shape_list,
const std::vector<transformer_engine::DType>& scale_type_list,
std::vector<void*> cast_output_dptr_list, // o
const std::vector<std::vector<size_t>>& cast_output_shape_list,
const std::vector<transformer_engine::DType>& cast_output_type_list,
std::vector<void*> transposed_output_dptr_list, // o
const std::vector<std::vector<size_t>>& transposed_output_shape_list,
const std::vector<transformer_engine::DType>& transposed_output_type_list,
std::vector<void*> amax_dptr_list, // o
const std::vector<std::vector<size_t>>& amax_shape_list,
const std::vector<transformer_engine::DType>& amax_type_list,
std::vector<void*> scale_inv_dptr_list, // o
const std::vector<std::vector<size_t>>& scale_inv_shape_list,
const std::vector<transformer_engine::DType>& scale_inv_type_list
) {
transformer_engine::TensorWrapper workspace;
// Construct TE tensors
std::vector<NVTETensor> input_list,
cast_output_list, transposed_output_list;
std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
auto make_tensor = [&tensor_wrappers](void* dptr,
const std::vector<size_t>& shape,
transformer_engine::DType dtype,
void* amax_dptr,
void* scale_dptr,
void* scale_inv_dptr)
-> NVTETensor {
tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr,
scale_dptr, scale_inv_dptr));
return tensor_wrappers.back().data();
};
for (size_t i = 0; i < input_dptr_list.size(); ++i) {
input_list.emplace_back(make_tensor(input_dptr_list[i],
input_shape_list[i],
input_type_list[i],
nullptr,
nullptr,
nullptr));
cast_output_list.emplace_back(make_tensor(cast_output_dptr_list[i],
cast_output_shape_list[i],
cast_output_type_list[i],
amax_dptr_list[i],
scale_dptr_list[i],
scale_inv_dptr_list[i]));
transposed_output_list.emplace_back(make_tensor(transposed_output_dptr_list[i],
transposed_output_shape_list[i],
transposed_output_type_list[i],
amax_dptr_list[i],
scale_dptr_list[i],
scale_inv_dptr_list[i]));
}
// Check tensor lists
NVTE_CHECK(cast_output_list.size() == input_list.size(),
"Number of input and C output tensors must match");
NVTE_CHECK(transposed_output_list.size() == input_list.size(),
"Number of input and T output tensors must match");
// Launch TE kernel
nvte_multi_cast_transpose(input_list.size(),
input_list.data(),
cast_output_list.data(),
transposed_output_list.data(),
at::cuda::getCurrentCUDAStream());
}
@@ -153,158 +153,4 @@ at::Tensor allocateTorchTensor(int M,
 );
void dispatch_layernorm(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* gamma, // i
const std::vector<size_t>& gamma_shape,
const transformer_engine::DType gamma_type,
void* beta, // i
const std::vector<size_t>& beta_shape,
const transformer_engine::DType beta_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
const float epsilon, // i
void* z, // o
const std::vector<size_t>& z_shape,
const transformer_engine::DType z_type,
void* mu, // o
const std::vector<size_t>& mu_shape,
const transformer_engine::DType mu_type,
void* rsigma, // o
const std::vector<size_t>& rsigma_shape,
const transformer_engine::DType rsigma_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type,
const int multiProcessorCount
);
void dispatch_cast_transpose_fusion(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
void* output_cast, // o
const std::vector<size_t>& output_cast_shape,
const transformer_engine::DType output_cast_type,
void* output_transpose, // o
const std::vector<size_t>& output_transpose_shape,
const transformer_engine::DType output_transpose_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type
);
void dispatch_gelu(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
void* output, // o
const std::vector<size_t>& output_shape,
const transformer_engine::DType output_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type
);
void dispatch_transpose(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* output, // o
const std::vector<size_t>& output_shape,
const transformer_engine::DType output_type
);
void dispatch_bgrad_cast_transpose_fusion(void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
void* cast_output, // o
const std::vector<size_t>& cast_output_shape,
const transformer_engine::DType cast_output_type,
void* transposed_output, // o
const std::vector<size_t>& transposed_output_shape,
const transformer_engine::DType transposed_output_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* dbias, // o
const std::vector<size_t>& dbias_shape,
const transformer_engine::DType dbias_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type
);
void dispatch_bgrad_dgelu_cast_transpose_fusion(
void* input, // i
const std::vector<size_t>& input_shape,
const transformer_engine::DType input_type,
void* gelu_input, // i
const std::vector<size_t>& gelu_input_shape,
const transformer_engine::DType gelu_input_type,
void* scale, // i
const std::vector<size_t>& scale_shape,
const transformer_engine::DType scale_type,
void* cast_output, // o
const std::vector<size_t>& cast_output_shape,
const transformer_engine::DType cast_output_type,
void* transposed_output, // o
const std::vector<size_t>& transposed_output_shape,
const transformer_engine::DType transposed_output_type,
void* amax, // o
const std::vector<size_t>& amax_shape,
const transformer_engine::DType amax_type,
void* dbias, // o
const std::vector<size_t>& dbias_shape,
const transformer_engine::DType dbias_type,
void* scale_inv, // o
const std::vector<size_t>& scale_inv_shape,
const transformer_engine::DType scale_inv_type
);
void dispatch_multi_cast_transpose(
std::vector<void*> input_dptr_list, // i
const std::vector<std::vector<size_t>>& input_shape_list,
const std::vector<transformer_engine::DType>& input_type_list,
std::vector<void*> scale_dptr_list, // i
const std::vector<std::vector<size_t>>& scale_shape_list,
const std::vector<transformer_engine::DType>& scale_type_list,
std::vector<void*> cast_output_dptr_list, // o
const std::vector<std::vector<size_t>>& cast_output_shape_list,
const std::vector<transformer_engine::DType>& cast_output_type_list,
std::vector<void*> transposed_output_dptr_list, // o
const std::vector<std::vector<size_t>>& transposed_output_shape_list,
const std::vector<transformer_engine::DType>& transposed_output_type_list,
std::vector<void*> amax_dptr_list, // o
const std::vector<std::vector<size_t>>& amax_shape_list,
const std::vector<transformer_engine::DType>& amax_type_list,
std::vector<void*> scale_inv_dptr_list, // o
const std::vector<std::vector<size_t>>& scale_inv_shape_list,
const std::vector<transformer_engine::DType>& scale_inv_type_list
);
#endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_
@@ -83,15 +83,16 @@ void fused_cast_transpose(at::Tensor input,
   size_t M = static_cast<size_t>(input.size(0));
   size_t N = static_cast<size_t>(input.size(1));

-  DType inp_type = GetTransformerEngineDType(input.scalar_type());
-
-  dispatch_cast_transpose_fusion(
-      input.data_ptr(), {M, N}, inp_type,
-      scale.data_ptr(), {1}, DType::kFloat32,
-      input_cast.data_ptr(), {M, N}, otype,
-      input_transpose.data_ptr(), {N, M}, otype,
-      amax.data_ptr(), {1}, DType::kFloat32,
-      scale_inv.data_ptr(), {1}, DType::kFloat32);
+  auto input_cu = makeTransformerEngineTensor(input);
+  auto output_cast_cu = makeTransformerEngineTensor(input_cast.data_ptr(), {M, N}, otype,
+                                                    amax.data_ptr(), scale.data_ptr(),
+                                                    scale_inv.data_ptr());
+  auto output_transpose_cu = makeTransformerEngineTensor(input_transpose.data_ptr(), {N, M}, otype,
+                                                         amax.data_ptr(), scale.data_ptr(),
+                                                         scale_inv.data_ptr());
+
+  nvte_cast_transpose(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(),
+                      at::cuda::getCurrentCUDAStream());
 }
@@ -117,14 +118,29 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output,
                           grad_output.size(0),
                           DType::kByte);

-  dispatch_bgrad_cast_transpose_fusion(
-      grad_output.data_ptr(), {M, N}, grad_output_type,
-      scale.data_ptr(), {1}, DType::kFloat32,
-      grad_output_cast.data_ptr(), {M, N}, otype,
-      grad_output_transpose.data_ptr(), {N, M}, otype,
-      amax.data_ptr(), {1}, DType::kFloat32,
-      grad_bias.data_ptr(), {N}, grad_output_type,
-      scale_inv.data_ptr(), {1}, DType::kFloat32);
+  auto input_cu = makeTransformerEngineTensor(grad_output);
+  auto cast_output_cu = makeTransformerEngineTensor(grad_output_cast.data_ptr(), {M, N},
+                                                    otype, amax.data_ptr(), scale.data_ptr(),
+                                                    scale_inv.data_ptr());
+  auto transposed_output_cu = makeTransformerEngineTensor(grad_output_transpose.data_ptr(),
+                                                          {N, M}, otype, amax.data_ptr(),
+                                                          scale.data_ptr(), scale_inv.data_ptr());
+  auto dbias_cu = makeTransformerEngineTensor(grad_bias);
+
+  transformer_engine::TensorWrapper workspace;
+
+  nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(),
+                            transposed_output_cu.data(), dbias_cu.data(),
+                            workspace.data(), at::cuda::getCurrentCUDAStream());
+
+  // Fill workspace
+  auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
+  workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
+                                          workspace.shape(),
+                                          workspace.dtype());
+
+  nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(),
+                            transposed_output_cu.data(), dbias_cu.data(),
+                            workspace.data(), at::cuda::getCurrentCUDAStream());
+
   return {grad_bias, grad_output_cast, grad_output_transpose};
 }
@@ -153,15 +169,32 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
                           grad_output.size(0),
                           DType::kByte);

-  dispatch_bgrad_dgelu_cast_transpose_fusion(
-      grad_output.data_ptr(), {M, N}, grad_output_type,
-      gelu_input.data_ptr(), {M, N}, grad_output_type,
-      scale.data_ptr(), {1}, DType::kFloat32,
-      dgelu.data_ptr(), {M, N}, otype,
-      dgelu_transpose.data_ptr(), {N, M}, otype,
-      amax.data_ptr(), {1}, DType::kFloat32,
-      grad_bias.data_ptr(), {N}, grad_output_type,
-      scale_inv.data_ptr(), {1}, DType::kFloat32);
+  transformer_engine::TensorWrapper workspace;
+
+  auto gelu_input_cu = makeTransformerEngineTensor(gelu_input);
+  auto input_cu = makeTransformerEngineTensor(grad_output);
+  auto cast_output_cu = makeTransformerEngineTensor(dgelu.data_ptr(), {M, N},
+                                                    otype, amax.data_ptr(), scale.data_ptr(),
+                                                    scale_inv.data_ptr());
+  auto transposed_output_cu = makeTransformerEngineTensor(dgelu_transpose.data_ptr(), {N, M},
+                                                          otype, amax.data_ptr(), scale.data_ptr(),
+                                                          scale_inv.data_ptr());
+  auto dbias_cu = makeTransformerEngineTensor(grad_bias);
+
+  nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(),
+                                  cast_output_cu.data(), transposed_output_cu.data(),
+                                  dbias_cu.data(), workspace.data(),
+                                  at::cuda::getCurrentCUDAStream());
+
+  // Fill workspace
+  auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
+  workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
+                                          workspace.shape(),
+                                          workspace.dtype());
+
+  nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(),
+                                  cast_output_cu.data(), transposed_output_cu.data(),
+                                  dbias_cu.data(), workspace.data(),
+                                  at::cuda::getCurrentCUDAStream());
+
   return {grad_bias, dgelu, dgelu_transpose};
 }
@@ -234,26 +267,56 @@ void fused_multi_cast_transpose(std::vector<at::Tensor> input_list,
                              scale_inv_type_list);
   }
transformer_engine::TensorWrapper workspace;
// Construct TE tensors
std::vector<NVTETensor> nvte_input_list,
nvte_cast_output_list, nvte_transposed_output_list;
std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
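  // make_tensor stores each TensorWrapper in tensor_wrappers so it stays alive for the rest of
  // this function, and returns the raw NVTETensor handle that the kernel launch below consumes.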
auto make_tensor = [&tensor_wrappers](void* dptr,
const std::vector<size_t>& shape,
transformer_engine::DType dtype,
void* amax_dptr,
void* scale_dptr,
void* scale_inv_dptr)
-> NVTETensor {
tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr,
scale_dptr, scale_inv_dptr));
return tensor_wrappers.back().data();
};
for (size_t i = 0; i < input_dptr_list.size(); ++i) {
nvte_input_list.emplace_back(make_tensor(input_dptr_list[i],
input_shape_list[i],
input_type_list[i],
nullptr,
nullptr,
nullptr));
nvte_cast_output_list.emplace_back(make_tensor(cast_output_dptr_list[i],
cast_output_shape_list[i],
cast_output_type_list[i],
amax_dptr_list[i],
scale_dptr_list[i],
scale_inv_dptr_list[i]));
nvte_transposed_output_list.emplace_back(make_tensor(transposed_output_dptr_list[i],
transposed_output_shape_list[i],
transposed_output_type_list[i],
amax_dptr_list[i],
scale_dptr_list[i],
scale_inv_dptr_list[i]));
}
// Check tensor lists
NVTE_CHECK(nvte_cast_output_list.size() == nvte_input_list.size(),
"Number of input and C output tensors must match");
NVTE_CHECK(nvte_transposed_output_list.size() == nvte_input_list.size(),
"Number of input and T output tensors must match");
   // Launch TE kernel
-  dispatch_multi_cast_transpose(
-      input_dptr_list,
-      input_shape_list,
-      input_type_list,
-      scale_dptr_list,
-      scale_shape_list,
-      scale_type_list,
-      cast_output_dptr_list,
-      cast_output_shape_list,
-      cast_output_type_list,
-      transposed_output_dptr_list,
-      transposed_output_shape_list,
-      transposed_output_type_list,
-      amax_dptr_list,
-      amax_shape_list,
-      amax_type_list,
-      scale_inv_dptr_list,
-      scale_inv_shape_list,
-      scale_inv_type_list);
+  nvte_multi_cast_transpose(nvte_input_list.size(),
+                            nvte_input_list.data(),
+                            nvte_cast_output_list.data(),
+                            nvte_transposed_output_list.data(),
+                            at::cuda::getCurrentCUDAStream());
 }
@@ -265,14 +328,17 @@ at::Tensor fp8_transpose(at::Tensor input,
   size_t M = static_cast<size_t>(input.size(0));
   size_t N = static_cast<size_t>(input.size(1));

-  auto input_transpose =
+  auto output =
               allocateTorchTensor(input.size(1),
                                   input.size(0),
                                   DType::kByte);

-  dispatch_transpose(input.data_ptr(), {M, N}, otype,
-                     input_transpose.data_ptr(), {N, M}, otype);
-
-  return input_transpose;
+  auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype);
+  auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype);
+
+  nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
+
+  return output;
 }
@@ -287,18 +353,17 @@ at::Tensor fp8_gelu(at::Tensor input,
   size_t M = static_cast<size_t>(input.size(0));
   size_t N = static_cast<size_t>(input.size(1));

-  DType input_type = GetTransformerEngineDType(input.scalar_type());
   auto output =
               allocateTorchTensor(input.size(0),
                                   input.size(1),
                                   DType::kByte);

-  dispatch_gelu(input.data_ptr(), {M, N}, input_type,
-                scale.data_ptr(), {1}, DType::kFloat32,
-                output.data_ptr(), {M, N}, otype,
-                amax.data_ptr(), {1}, DType::kFloat32,
-                scale_inv.data_ptr(), {1}, DType::kFloat32);
+  auto input_cu = makeTransformerEngineTensor(input);
+  auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype,
+                                               amax.data_ptr(), scale.data_ptr(),
+                                               scale_inv.data_ptr());
+
+  nvte_gelu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());

   return output;
 }
@@ -379,19 +444,40 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
   auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(otype)));
   auto mu = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
   auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));

-  dispatch_layernorm(
-      input.data_ptr(), {N, H}, itype,
-      weight.data_ptr(), {H}, itype,
-      bias.data_ptr(), {H}, itype,
-      scale.data_ptr(), {1}, DType::kFloat32,
-      eps,
-      ln_out.data_ptr(), {N, H}, otype,
-      mu.data_ptr(), {N}, DType::kFloat32,
-      rsigma.data_ptr(), {N}, DType::kFloat32,
-      amax.data_ptr(), {1}, DType::kFloat32,
-      scale_inv.data_ptr(), {1}, DType::kFloat32,
-      at::cuda::getCurrentDeviceProperties()->multiProcessorCount);
+  auto input_cu = makeTransformerEngineTensor(input);
+  auto gamma_cu = makeTransformerEngineTensor(weight);
+  auto beta_cu = makeTransformerEngineTensor(bias);
+  auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype,
+                                          amax.data_ptr(), scale.data_ptr(),
+                                          scale_inv.data_ptr());
+  auto mu_cu = makeTransformerEngineTensor(mu);
+  auto rsigma_cu = makeTransformerEngineTensor(rsigma);
+  transformer_engine::TensorWrapper workspace, barrier;
+
+  // This call populates workspace and barrier tensors with the required config
+  nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
+                     mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
+                     at::cuda::getCurrentDeviceProperties()->multiProcessorCount,
+                     workspace.data(), barrier.data());
+
+  // Fill workspace and barrier
+  auto workspace_data = allocateSpace(workspace.shape(),
+                                      workspace.dtype());
+  auto barrier_data = allocateSpace(barrier.shape(),
+                                    barrier.dtype(),
+                                    true);
+  workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
+                                          workspace.shape(),
+                                          workspace.dtype());
+  barrier = makeTransformerEngineTensor(barrier_data.data_ptr(),
+                                        barrier.shape(),
+                                        barrier.dtype());
+
+  // Actual call to fwd kernel
+  nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
+                     mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
+                     at::cuda::getCurrentDeviceProperties()->multiProcessorCount,
+                     workspace.data(), barrier.data());
+
   return {ln_out, mu, rsigma};
 }
@@ -429,22 +515,43 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
   auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(itype)));
   auto mu = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
   auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));

-  dispatch_layernorm(input.data_ptr(), {N, H}, itype,
-                     weight.data_ptr(), {H}, itype,
-                     bias.data_ptr(), {H}, itype,
-                     nullptr, {1}, DType::kFloat32,
-                     eps,
-                     ln_out.data_ptr(), {N, H}, itype,
-                     mu.data_ptr(), {N}, DType::kFloat32,
-                     rsigma.data_ptr(), {N}, DType::kFloat32,
-                     nullptr, {1}, DType::kFloat32,
-                     nullptr, {1}, DType::kFloat32,
-                     at::cuda::getCurrentDeviceProperties()->multiProcessorCount);
+  auto input_cu = makeTransformerEngineTensor(input);
+  auto gamma_cu = makeTransformerEngineTensor(weight);
+  auto beta_cu = makeTransformerEngineTensor(bias);
+  auto z_cu = makeTransformerEngineTensor(ln_out);
+  auto mu_cu = makeTransformerEngineTensor(mu);
+  auto rsigma_cu = makeTransformerEngineTensor(rsigma);
+  transformer_engine::TensorWrapper workspace, barrier;
+
+  // This call populates workspace and barrier tensors with the required config
+  nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
+                     mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
+                     at::cuda::getCurrentDeviceProperties()->multiProcessorCount,
+                     workspace.data(), barrier.data());
+
+  // Fill workspace and barrier
+  auto workspace_data = allocateSpace(workspace.shape(),
+                                      workspace.dtype());
+  auto barrier_data = allocateSpace(barrier.shape(),
+                                    barrier.dtype(),
+                                    true);
+  workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
+                                          workspace.shape(),
+                                          workspace.dtype());
+  barrier = makeTransformerEngineTensor(barrier_data.data_ptr(),
+                                        barrier.shape(),
+                                        barrier.dtype());
+
+  // Actual call to fwd kernel
+  nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
+                     mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
+                     at::cuda::getCurrentDeviceProperties()->multiProcessorCount,
+                     workspace.data(), barrier.data());
+
   return {ln_out, mu, rsigma};
 }
 at::Tensor layernorm_fwd_inf(const at::Tensor &input,
                              const at::Tensor &weight,
                              const at::Tensor &bias,
@@ -456,6 +563,7 @@ at::Tensor layernorm_fwd_inf(const at::Tensor &input,
   return out[0];
 }

 at::Tensor cast_to_fp8(const at::Tensor &input,
                        const at::Tensor &scale,
                        at::Tensor amax,
@@ -494,7 +602,6 @@ at::Tensor cast_from_fp8(const at::Tensor &input,
   auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {N, H}, itype,
                                               nullptr, nullptr, scale_inv.data_ptr());
   auto output_cu = makeTransformerEngineTensor(output);
-  auto scale_inv_cu = makeTransformerEngineTensor(scale_inv.data_ptr(), {1}, DType::kFloat32);
   nvte_fp8_dequantize(input_cu.data(), output_cu.data(),
                       at::cuda::getCurrentCUDAStream());
...