Commit 0bb60881 authored by Rick Ho

bf16 support with test

parent c9ccc0eb
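The backend half of this commit is mechanical: every AT_DISPATCH_FLOATING_TYPES_AND_HALF(...) call in the CUDA sources becomes AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, ...), so each dispatched lambda is also instantiated with scalar_t = c10::BFloat16. A minimal sketch of how the macro is used follows; the scale_cpu helper and scale_tensor wrapper are illustrative only and not part of this commit.

#include <torch/extension.h>

// Hypothetical helper (not part of the commit): scale a CPU tensor in place.
template <typename scalar_t>
void scale_cpu(scalar_t* data, int64_t n, float factor) {
    for (int64_t i = 0; i < n; ++i) {
        data[i] = static_cast<scalar_t>(static_cast<float>(data[i]) * factor);
    }
}

void scale_tensor(torch::Tensor t, float factor) {
    // The two leading scalar types name the dtypes handled on top of
    // float/double; the lambda body is compiled once per dtype.
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half, at::ScalarType::BFloat16,
        t.scalar_type(), "scale_tensor", ([&] {
            scale_cpu<scalar_t>(t.data_ptr<scalar_t>(), t.numel(), factor);
        }));
}

Compared with AT_DISPATCH_FLOATING_TYPES_AND_HALF, the only difference is the explicit bf16 entry, which is exactly what the hunks below apply to the reduce-grad, global scatter/gather, linear, and smart-schedule kernels.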
@@ -52,8 +52,8 @@ void _reduce_grad(
     cudaEventDestroy(evt_stash);
     auto dtype = getNcclDataType(t.scalar_type());
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(t.scalar_type(),
-            "fmoe_cuda_reduce_grad", ([&] {
+    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
+            t.scalar_type(), "fmoe_cuda_reduce_grad", ([&] {
         void* buf = (void*)t.data_ptr<scalar_t>();
         NCCL_SAFE_CALL(ncclReduce(buf, buf, expert_size,
                 dtype,
@@ -110,8 +110,8 @@ std::vector<torch::Tensor> _smart_sch_forward(
         }
     }
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
-            "fmoe_cuda_smart_sch_forward", ([&] {
+    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
+            input_buf.scalar_type(), "fmoe_cuda_smart_sch_forward", ([&] {
         fmoe_cuda_fused_forward_impl(
             forward_fn,
             stash_fn,
...
@@ -58,8 +58,8 @@ torch::Tensor _global_scatter(
     auto global_input_buf = input_buf.new_empty({batch_size, in_feat});
     auto smgr = getCudaStreamManager(input_buf.device().index());
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
-            "fmoe_cuda_global_scatter", ([&] {
+    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
+            input_buf.scalar_type(), "fmoe_cuda_global_scatter", ([&] {
         fmoe_cuda_global_scatter_impl<scalar_t>(
             input_buf.data_ptr<scalar_t>(),
             local_expert_count.data_ptr<long>(),
@@ -84,8 +84,8 @@ torch::Tensor _global_gather(
     auto local_output_buf = output_buf.new_empty({batch_size, out_feat});
     auto smgr = getCudaStreamManager(output_buf.device().index());
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(output_buf.scalar_type(),
-            "fmoe_cuda_global_gather", ([&] {
+    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
+            output_buf.scalar_type(), "fmoe_cuda_global_gather", ([&] {
         fmoe_cuda_global_gather_impl<scalar_t>(
             output_buf.data_ptr<scalar_t>(),
             local_expert_count.data_ptr<long>(),
...
@@ -30,7 +30,8 @@ torch::Tensor _linear_forward(
         output = torch::empty({batch_size, out_feat}, out_options);
     }
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_forward_cuda",
+    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
+            input_buf.scalar_type(), "moe_forward_cuda",
             ([&] {
         fmoe_cuda_linear_forward_impl<scalar_t>(
             input_buf.data_ptr<scalar_t>(),
@@ -72,7 +73,8 @@ std::vector<torch::Tensor> _linear_backward(
     auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});
     auto grad_bias = grad_output_buf.new_empty({num_expert, out_feat});
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(), "moe_cuda_backward", ([&] {
+    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
+            input_buf.scalar_type(), "moe_cuda_backward", ([&] {
         fmoe_cuda_linear_backward_impl<scalar_t>(
             grad_output_buf.data_ptr<scalar_t>(),
             input_buf.data_ptr<scalar_t>(),
...
@@ -2,6 +2,7 @@
 #define CUBLAS_WRAPPER_H
 #include <cublas_v2.h>
 #include <c10/util/Half.h>
+#include <c10/util/BFloat16.h>

 inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
         cublasOperation_t transa,
@@ -108,5 +109,26 @@ inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
             (__half*)C, ldc);
 #endif
 }
+
+inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
+        cublasOperation_t transa, cublasOperation_t transb,
+        int m, int n, int k,
+        const c10::BFloat16 *alpha,
+        const c10::BFloat16 *A, int lda,
+        const c10::BFloat16 *B, int ldb,
+        const c10::BFloat16 *beta,
+        c10::BFloat16 *C, int ldc) {
+#ifdef FMOE_USE_HIP
+    // TODO: Support bf16 for HIP
+    assert(false);
+#else
+    // cublasSgemmEx expects fp32 scaling factors and, for bf16 data,
+    // the CUDA_R_16BF type tag; it accumulates in fp32 internally.
+    const float alpha_fp32 = *alpha;
+    const float beta_fp32 = *beta;
+    return cublasSgemmEx(handle, transa, transb, m, n, k,
+            &alpha_fp32,
+            (const void*)A, CUDA_R_16BF, lda,
+            (const void*)B, CUDA_R_16BF, ldb,
+            &beta_fp32,
+            (void*)C, CUDA_R_16BF, ldc);
+#endif
+}

 #endif  // CUBLAS_WRAPPER_H
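The bf16 overload above keeps the accumulation in fp32 by going through cublasSgemmEx. For comparison, the same wrapper could be written against the more general cublasGemmEx entry point with an explicit compute type; this variant is only a sketch, is not part of the commit, and assumes CUDA 11 or later (the cublasXgemmBf16ViaGemmEx name is hypothetical).

#include <cublas_v2.h>
#include <c10/util/BFloat16.h>

// Sketch: bf16 inputs and outputs with explicit fp32 accumulation via cublasGemmEx.
inline cublasStatus_t cublasXgemmBf16ViaGemmEx(cublasHandle_t handle,
        cublasOperation_t transa, cublasOperation_t transb,
        int m, int n, int k,
        const c10::BFloat16 *alpha,
        const c10::BFloat16 *A, int lda,
        const c10::BFloat16 *B, int ldb,
        const c10::BFloat16 *beta,
        c10::BFloat16 *C, int ldc) {
    // With CUBLAS_COMPUTE_32F the scaling factors are read as fp32.
    const float alpha_fp32 = *alpha;
    const float beta_fp32 = *beta;
    return cublasGemmEx(handle, transa, transb, m, n, k,
            &alpha_fp32,
            A, CUDA_R_16BF, lda,
            B, CUDA_R_16BF, ldb,
            &beta_fp32,
            C, CUDA_R_16BF, ldc,
            CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
}

Either way, accumulating in fp32 is what allows the still-tight tolerance used for bf16 in the tests further below, despite bf16's reduced mantissa precision.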
@@ -41,7 +41,7 @@ else:
 if __name__ == '__main__':
     setuptools.setup(
         name='fastmoe',
-        version='1.0.1',
+        version='1.0.2',
         description='An efficient Mixture-of-Experts system for PyTorch',
         author=', '.join(authors),
         author_email='hja20@mails.tsinghua.edu.cn',
...
@@ -80,7 +80,7 @@ class NaiveExpert(nn.Module):
         super(NaiveExpert, self).__init__()
         self.linear = nn.Linear(d_model, d_model).cuda()

-    def forward(self, x):
+    def forward(self, x, fec=None):
         return self.linear(x)
@@ -91,5 +91,5 @@ class LinearExpert(nn.Module):
             nn.Linear(d_model, d_model * 2), nn.ReLU(), nn.Linear(d_model * 2, d_model),
         ).cuda()

-    def forward(self, x):
+    def forward(self, x, fec=None):
         return self.model(x)
@@ -108,7 +108,7 @@ class MyMoE(FMoE):
 @pytest.mark.parametrize("dp_group", [None])
 @pytest.mark.parametrize("world_group", [None])
 @pytest.mark.parametrize(
-    "data_type", ["torch.FloatTensor", "torch.DoubleTensor", "torch.HalfTensor"]
+    "data_type", [torch.float32, torch.float16, torch.bfloat16, torch.double]
 )
 @pytest.mark.parametrize("list_input", [False, True])
 def test_fmoe_mimo_linear(
@@ -138,9 +138,9 @@ def test_fmoe_mimo_linear(
         mp_group=mp_group,
         top_k=top_k,
         activation=activation,
-    ).cuda()
+    ).cuda().to(data_type)

-    x = torch.rand(batch_size, d_model).cuda()
+    x = torch.rand(batch_size, d_model).cuda().to(data_type)
     inp = [x, x.clone()] if list_input else {"x": x, "y": x.clone()}
     moe_out = moe(inp)
@@ -162,6 +162,6 @@ if __name__ == "__main__":
         mp_group=None,
         dp_group=None,
         world_group=None,
-        data_type=torch.float32,
+        data_type=torch.bfloat16,
         list_input=True
     )
@@ -51,6 +51,8 @@ def _perform_forward(
 def _assert_numerical(names, moe_out_list, raw_out_list, rank, precision=1e-3):
     for name, mo, ro in zip(names, moe_out_list, raw_out_list):
         err = (mo - ro).abs().max()
+        if err.dtype == torch.bfloat16:
+            precision *= 100
         print("Rank {} {} abs err {}".format(rank, name, err))
         if err > precision:
             sys.stderr.write(f"=========== {name} moe out ==============\n")
@@ -217,6 +219,7 @@ def test_fmoe_linear(
 @pytest.mark.parametrize("mp_group", [None])
 @pytest.mark.parametrize("dp_group", [None])
 @pytest.mark.parametrize("world_group", [None])
+@pytest.mark.parametrize("data_type", [torch.float32, torch.float16, torch.bfloat16])
 def test_fmoe(
     batch_size,
     num_expert,
@@ -228,6 +231,7 @@ def test_fmoe(
     mp_group,
     dp_group,
     world_group,
+    data_type
 ):
     torch.manual_seed(42 + rank)
     torch.cuda.manual_seed(42 + rank)
@@ -243,7 +247,7 @@ def test_fmoe(
         mp_group=mp_group,
         expert=expert,
         top_k=top_k,
-    ).cuda()
+    ).cuda().to(data_type)
     moe_raw = BruteForceMoE(
         expert=expert,
@@ -251,7 +255,7 @@
         d_model=d_model,
         world_size=world_size,
         top_k=top_k,
-    ).cuda()
+    ).cuda().to(data_type)
     if world_size == 1:
         for expert_moe, expert_raw in zip(moe.experts, moe_raw.experts):
@@ -275,7 +279,7 @@
             ].data = para_tensor_gathered[expertID]
     moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
-        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
+        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group, data_type
     )

 def get_experts_grad(experts: List[nn.Module]):
@@ -396,6 +400,7 @@ def _test_fmoe_local_ddp(rank, world_size, mp_group, dp_group, world_group):
 @pytest.mark.parametrize("mp_group", [None])
 @pytest.mark.parametrize("dp_group", [None])
 @pytest.mark.parametrize("world_group", [None])
+@pytest.mark.parametrize("data_type", [torch.float32])
 def test_fmoe_experts(
     batch_size,
     num_expert,
@@ -407,6 +412,7 @@ def test_fmoe_experts(
     mp_group,
     dp_group,
     world_group,
+    data_type
 ):
     torch.manual_seed(42 + rank)
     torch.cuda.manual_seed(42 + rank)
@@ -422,7 +428,7 @@ def test_fmoe_experts(
         mp_group=mp_group,
         expert=expert,
         top_k=top_k,
-    ).cuda()
+    ).cuda().to(data_type)
     moe_raw = BruteForceMoE(
         expert=expert,
@@ -430,7 +436,7 @@
         d_model=d_model,
         world_size=world_size,
         top_k=top_k,
-    ).cuda()
+    ).cuda().to(data_type)
     if world_size == 1:
         for expert_moe, expert_raw in zip(moe.experts, moe_raw.experts):
@@ -454,7 +460,7 @@ def test_fmoe_experts(
             ].data = para_tensor_gathered[expertID]
     moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
-        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
+        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group, data_type
     )

 def get_experts_grad(experts: List[nn.Module]):
@@ -488,16 +494,16 @@ def test_fmoe_experts(
 if __name__ == "__main__":
-    test_fmoe_linear(
+    test_fmoe(
         batch_size=2,
         num_expert=2,
         d_model=2,
         top_k=2,
-        d_hidden=16,
+        expert=[NaiveExpert for _ in range(4)],
         rank=0,
         world_size=1,
         mp_group=None,
         dp_group=None,
         world_group=None,
-        data_type=torch.float32,
+        data_type=torch.bfloat16
     )