Commit 49c97411 authored by Rick Ho

support fp16

parent 952e3135
 #ifndef CUBLAS_WRAPPER_H
 #define CUBLAS_WRAPPER_H
 #include <cublas_v2.h>
+#include <c10/util/Half.h>
 inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
         cublasOperation_t transa,
@@ -74,5 +75,21 @@ inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
         __half *C, int ldc) {
     return cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
 }
+inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
+        cublasOperation_t transa, cublasOperation_t transb,
+        int m, int n, int k,
+        const c10::Half *alpha,
+        const c10::Half *A, int lda,
+        const c10::Half *B, int ldb,
+        const c10::Half *beta,
+        c10::Half *C, int ldc) {
+    return cublasHgemm(handle, transa, transb, m, n, k,
+            (const __half*)alpha,
+            (const __half*)A, lda,
+            (const __half*)B, ldb,
+            (const __half*)beta,
+            (__half*)C, ldc);
+}
 #endif // CUBLAS_WRAPPER_H
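Why the extra overload: PyTorch's dispatch macros instantiate kernel templates with scalar_t = c10::Half (ATen's fp16 type), not CUDA's __half, so templated callers of cublasXgemm need an overload that accepts c10::Half pointers. Because c10::Half is a 2-byte type with the same bit layout as __half, reinterpret-casting the pointers and forwarding to cublasHgemm is sound. A minimal sketch of how overload resolution plays out (gemm_example and its arguments are hypothetical, not part of the commit; only cublasXgemm comes from the header above):

#include <cublas_v2.h>
#include <c10/util/Half.h>

template <typename scalar_t>
void gemm_example(cublasHandle_t handle,
                  const scalar_t *A, const scalar_t *B, scalar_t *C,
                  int m, int n, int k) {
    const scalar_t alpha = scalar_t(1.0f);
    const scalar_t beta  = scalar_t(0.0f);
    // scalar_t = float/double resolves to the existing cublasSgemm/
    // cublasDgemm wrappers; scalar_t = c10::Half now resolves to the
    // overload added above, which forwards to cublasHgemm.
    cublasXgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                &alpha, A, m, B, k, &beta, C, m);
}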
@@ -112,7 +112,7 @@ std::vector<torch::Tensor> moe_cuda_global_scatter(
     auto global_input_buf = input_buf.new_empty({batch_size, in_feat});
     auto smgr = getCudaStreamManager(input_buf.device().index());
-    AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(),
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
             "moe_cuda_global_scatter", ([&] {
         moe_cuda_global_scatter_impl<scalar_t>(
             input_buf.data_ptr<scalar_t>(),
@@ -182,7 +182,7 @@ std::vector<torch::Tensor> moe_cuda_global_gather(
     auto local_output_buf = output_buf.new_empty({batch_size, out_feat});
     auto smgr = getCudaStreamManager(output_buf.device().index());
-    AT_DISPATCH_FLOATING_TYPES(output_buf.scalar_type(),
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(output_buf.scalar_type(),
             "moe_cuda_global_gather", ([&] {
         moe_cuda_global_gather_impl<scalar_t>(
             output_buf.data_ptr<scalar_t>(),
...
@@ -233,7 +233,7 @@ std::vector<torch::Tensor> moe_cuda_local_scatter(
     auto input_buf = torch::empty_like(input);
-    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_local_scatter_cuda",
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "moe_local_scatter_cuda",
             ([&] {
         moe_cuda_local_scatter_impl<scalar_t>(
             input.data_ptr<scalar_t>(),
@@ -255,7 +255,7 @@ std::vector<torch::Tensor> moe_cuda_local_gather(
     auto output = torch::empty_like(output_buf);
-    AT_DISPATCH_FLOATING_TYPES(output_buf.scalar_type(), "moe_local_gather_cuda",
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(output_buf.scalar_type(), "moe_local_gather_cuda",
             ([&] {
         moe_cuda_local_gather_impl<scalar_t>(
             output_buf.data_ptr<scalar_t>(),
@@ -288,7 +288,7 @@ std::vector<torch::Tensor> moe_cuda_forward(
             .dtype(input_buf.dtype());
     auto output = torch::empty({batch_size, out_feat}, out_options);
-    AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(), "moe_forward_cuda",
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_forward_cuda",
             ([&] {
         moe_cuda_forward_impl<scalar_t>(
             input_buf.data_ptr<scalar_t>(),
@@ -326,7 +326,7 @@ std::vector<torch::Tensor> moe_cuda_backward(
     auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat});
     auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});
-    AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
         moe_cuda_backward_impl<scalar_t>(
             grad_output_buf.data_ptr<scalar_t>(),
             input_buf.data_ptr<scalar_t>(),
...
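Every hunk in this file makes the same one-line change. AT_DISPATCH_FLOATING_TYPES only generates float and double branches and raises a "not implemented" error for fp16 tensors; AT_DISPATCH_FLOATING_TYPES_AND_HALF adds an at::Half branch. A simplified sketch of what the macro expands to for the moe_cuda_global_scatter dispatch (the real macros in ATen/Dispatch.h carry extra bookkeeping; this is illustrative only):

switch (input_buf.scalar_type()) {
    case at::ScalarType::Double: { using scalar_t = double;   /* lambda body */ break; }
    case at::ScalarType::Float:  { using scalar_t = float;    /* lambda body */ break; }
    // This case is what the _AND_HALF variant adds; without it,
    // fp16 inputs fail at runtime instead of dispatching.
    case at::ScalarType::Half:   { using scalar_t = at::Half; /* lambda body */ break; }
    default:
        TORCH_CHECK(false, "moe_cuda_global_scatter not implemented for '",
                    toString(input_buf.scalar_type()), "'");
}

Since scalar_t becomes at::Half (an alias of c10::Half) in the new branch, the impl functions end up calling the c10::Half overload of cublasXgemm added in the header above.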
@@ -127,7 +127,7 @@ std::vector<torch::Tensor> moe_cuda_global_fused_forward(
     auto global_input_buf = input_buf.new_empty({global_batch_size, in_feat});
     auto global_output_buf = input_buf.new_empty({global_batch_size, out_feat});
     auto output_buf = input_buf.new_empty({local_batch_size, out_feat});
-    AT_DISPATCH_FLOATING_TYPES(input_buf.scalar_type(),
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
             "moe_cuda_global_fused_forward", ([&] {
         moe_cuda_global_fused_forward_impl(
             input_buf.data_ptr<scalar_t>(),
...
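All of the above rests on c10::Half being layout-compatible with CUDA's __half: both are 2-byte IEEE-754 binary16 values, which is what justifies the pointer casts in the cuBLAS wrapper. A compile-time guard one could add alongside the casts (not part of the commit):

#include <cuda_fp16.h>
#include <c10/util/Half.h>

// Fail the build if the two fp16 types ever diverge in size or
// alignment, which would make the reinterpret casts unsafe.
static_assert(sizeof(c10::Half) == sizeof(__half),
              "c10::Half and __half must have the same size");
static_assert(alignof(c10::Half) == alignof(__half),
              "c10::Half and __half must have the same alignment");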