Commit 35addec6 authored by Rick Ho

two-level matmul fix transpose

parent 191c1e46
@@ -7,7 +7,8 @@
 std::vector<torch::Tensor> moe_cuda_forward(
     torch::Tensor input,
     torch::Tensor gate,
-    torch::Tensor weight);
+    torch::Tensor weight1,
+    torch::Tensor weight2);
 
 std::vector<torch::Tensor> moe_cuda_backward(
     torch::Tensor grad_output,
@@ -25,17 +26,19 @@ std::vector<torch::Tensor> moe_cuda_backward(
 std::vector<torch::Tensor> moe_forward(
     torch::Tensor input,   // [batch_size x in_feat]
     torch::Tensor gate,    // [batch_size]
-    torch::Tensor weight   // [num_expert x out_feat x in_feat]
+    torch::Tensor weight1, // [num_expert x hidden_feat x in_feat]
+    torch::Tensor weight2  // [num_expert x out_feat x hidden_feat]
 ) {
     CHECK_INPUT(input);
     CHECK_INPUT(gate);
-    CHECK_INPUT(weight);
+    CHECK_INPUT(weight1);
+    CHECK_INPUT(weight2);
     /*
     The bias term should have been merged into the weight. Note that
        Wx + b = [W b] [x]
                       [1]
     */
-    return moe_cuda_forward(input, gate, weight);
+    return moe_cuda_forward(input, gate, weight1, weight2);
 }
 
 std::vector<torch::Tensor> moe_backward(
...
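The bias-merging identity in the comment above is easy to verify numerically. A minimal sketch in plain torch (illustration only, not part of the commit):

```python
import torch

out_feat, in_feat = 4, 3
W = torch.randn(out_feat, in_feat)
b = torch.randn(out_feat)
x = torch.randn(in_feat)

# Augment: [W b] absorbs the bias, [x; 1] absorbs the constant input.
W_aug = torch.cat([W, b.unsqueeze(1)], dim=1)   # (out_feat, in_feat + 1)
x_aug = torch.cat([x, torch.ones(1)])           # (in_feat + 1,)
assert torch.allclose(W @ x + b, W_aug @ x_aug)
```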
@@ -10,11 +10,11 @@ torch.cuda.manual_seed(42)
 class MOEFunction(Function):
     @staticmethod
-    def forward(ctx, inp, gate, weight):
-        out_feat, in_feat = weight.size()[1:]
-        weight_column_major = weight.transpose(-1, -2).contiguous().view(-1, out_feat, in_feat)
-        output = moe_cuda.forward(inp, gate, weight_column_major)
-        variables = [inp, gate, weight_column_major]
+    def forward(ctx, inp, gate, weight1, weight2):
+        # out_feat, in_feat = weight.size()[1:]
+        # weight_column_major = weight.transpose(-1, -2).contiguous().view(-1, out_feat, in_feat)
+        output = moe_cuda.forward(inp, gate, weight1, weight2)
+        variables = [inp, gate, weight1, weight2]
         ctx.save_for_backward(*variables)
         return output[0]
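This hunk is the "fix transpose" in the commit title: the Python side no longer materializes a column-major copy of the weights, because the kernel now asks cuBLAS for the transpose itself. A minimal sketch of why the copy was redundant (plain torch, illustration only):

```python
import torch

W = torch.randn(4, 3)        # row-major (out_feat, in_feat), as nn.Linear stores it
x = torch.randn(5, 3)        # a batch of row vectors
ref = x @ W.t()              # what the layer must compute, regardless of layout

# Old path: build a transposed, contiguous copy so a column-major GEMM could use op = N.
W_copy = W.t().contiguous()
assert torch.allclose(x @ W_copy, ref)
# New path: no copy; the kernel applies CUBLAS_OP_T to W's buffer directly.
```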
@@ -32,45 +32,59 @@ class MOEFunction(Function):
 class MOELayer(nn.Module):
-    def __init__(self, num_expert=32, in_feat=1024, out_feat=4096):
+    def __init__(self, num_expert=32, in_feat=1024, hidden_feat=4096, out_feat=1024):
         super(MOELayer, self).__init__()
         self.num_expert = num_expert
         self.in_feat = in_feat
+        self.hidden_feat = hidden_feat
         self.out_feat = out_feat
-        self.weight = nn.Parameter(
-            torch.Tensor(num_expert, out_feat, in_feat))
+        self.weight1 = nn.Parameter(
+            torch.Tensor(num_expert, hidden_feat, in_feat))
+        self.weight2 = nn.Parameter(
+            torch.Tensor(num_expert, out_feat, hidden_feat))
         self.reset_parameters()
 
     def reset_parameters(self):
         for i in range(self.num_expert):
-            linear = nn.Linear(in_features=self.in_feat, out_features=self.out_feat)
-            self.weight.data[i] = linear.weight.data
+            linear = nn.Linear(in_features=self.in_feat, out_features=self.hidden_feat)
+            self.weight1.data[i] = linear.weight.data
+            linear = nn.Linear(in_features=self.hidden_feat, out_features=self.out_feat)
+            self.weight2.data[i] = linear.weight.data
 
     def forward(self, inp, gate):
-        return MOEFunction.apply(inp, gate, self.weight)
+        return MOEFunction.apply(inp, gate, self.weight1, self.weight2)
 
 class MOELayer_raw(nn.Module):
-    def __init__(self, num_expert=32, in_feat=1024, out_feat=4096):
+    def __init__(self, num_expert=32, in_feat=1024, hidden_feat=4096, out_feat=1024):
         super(MOELayer_raw, self).__init__()
         self.num_expert = num_expert
         self.in_feat = in_feat
+        self.hidden_feat = hidden_feat
         self.out_feat = out_feat
-        self.weight = nn.Parameter(
-            torch.Tensor(num_expert, out_feat, in_feat))
+        self.weight1 = nn.Parameter(
+            torch.Tensor(num_expert, hidden_feat, in_feat))
+        self.weight2 = nn.Parameter(
+            torch.Tensor(num_expert, out_feat, hidden_feat))
         self.reset_parameters()
 
     def reset_parameters(self):
         for i in range(self.num_expert):
-            linear = nn.Linear(in_features=self.in_feat, out_features=self.out_feat)
-            self.weight.data[i] = linear.weight.data
+            linear = nn.Linear(in_features=self.in_feat, out_features=self.hidden_feat)
+            print(linear.weight.shape)
+            self.weight1.data[i] = linear.weight.data
+            linear = nn.Linear(in_features=self.hidden_feat, out_features=self.out_feat)
+            self.weight2.data[i] = linear.weight.data
 
     def forward(self, inp, gate):
         gate_long = gate.long()
         batch_size = inp.size(0)
         x = inp.new_zeros((batch_size, self.out_feat))
+        print(self.weight2)
         for i in range(batch_size):
-            x[i] = self.weight[gate_long[i]] @ inp[i]
+            hid = inp[i] @ self.weight1[gate_long[i]].t()
+            print(hid)
+            x[i] = hid @ self.weight2[gate_long[i]].t()
         return x
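The per-sample loop in MOELayer_raw is the readability-first reference for the two-level expert matmul. The same computation can be written batched with torch.bmm, which is handy for sanity checks; this is an illustrative sketch, not part of the commit:

```python
import torch

def moe_reference_batched(inp, gate, weight1, weight2):
    """Batched equivalent of MOELayer_raw.forward (illustration only).

    inp:     [batch, in_feat]
    gate:    [batch] integer expert index
    weight1: [num_expert, hidden_feat, in_feat]
    weight2: [num_expert, out_feat, hidden_feat]
    """
    w1 = weight1[gate.long()]                  # [batch, hidden_feat, in_feat]
    w2 = weight2[gate.long()]                  # [batch, out_feat, hidden_feat]
    hid = torch.bmm(w1, inp.unsqueeze(-1))     # [batch, hidden_feat, 1]
    return torch.bmm(w2, hid).squeeze(-1)      # [batch, out_feat]
```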
@@ -80,6 +94,8 @@ def test_module(moe, linear, inp, gate):
     x = linear(inp)
     output = moe(x, gate)
     print(output)
+    return output
+    print(output)
     y = output.mean()
     y.backward()
     return output, moe.weight.grad, linear.weight.grad, linear.bias.grad
@@ -87,15 +103,17 @@ def test_module(moe, linear, inp, gate):
 def test():
     batch_size = 4
-    num_expert = 4
-    in_feat = 2
-    out_feat = 3
+    num_expert = 2
+    in_feat = 6
+    hidden_feat = 12
+    out_feat = 7
 
     linear = nn.Linear(in_feat, in_feat).cuda()
-    moe = MOELayer(num_expert, in_feat, out_feat).cuda()
-    moe_raw = MOELayer_raw(num_expert, in_feat, out_feat).cuda()
-    moe_raw.weight.data = moe.weight.data.clone()
+    moe = MOELayer(num_expert, in_feat, hidden_feat, out_feat).cuda()
+    moe_raw = MOELayer_raw(num_expert, in_feat, hidden_feat, out_feat).cuda()
+    moe_raw.weight1.data = moe.weight1.data.clone()
+    moe_raw.weight2.data = moe.weight2.data.clone()
 
     inp = torch.rand(batch_size, in_feat).cuda()
     gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda()
@@ -104,6 +122,7 @@ def test():
     raw_out = test_module(moe_raw, linear, inp.clone(), gate.clone())
 
     names = ['Out', 'Moe wei', 'Linear wei', 'Linear bias']
+    names = ['Out']
     for name, mo, ro in zip(names, moe_out, raw_out):
         err = (mo - ro).abs().sum()
         print('{} abs err {}'.format(name, err))
...
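One caveat in the comparison above: since test_module now returns a single tensor, `zip(['Out'], moe_out, raw_out)` pairs the one name with only the first row of each output, so only the first sample is actually checked. A whole-tensor comparison would look like this hedged sketch:

```python
# Compare the full outputs rather than just the first row (illustration only).
err = (moe_out - raw_out).abs().max()
print('Out max abs err {}'.format(err))
assert torch.allclose(moe_out, raw_out, atol=1e-5), 'CUDA and reference outputs diverge'
```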
@@ -10,17 +10,20 @@
 #include <cublas_v2.h>
 #include <helper_cuda.h>
 
-// #include "timer.hh"
+#include "timer.hh"
 #include "cublas_wrapper.h"
 #include "cuda_stream_manager.h"
 
 #define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
+// #define MOE_BREAKDOWN
+#define MOE_DEBUG
 
 template <typename scalar_t>
 __global__
-void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride, const int* offset, const scalar_t** ptrs) {
+void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride,
+        const int* offset, const scalar_t** ptrs) {
     size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
     if (idx < n) {
         ptrs[idx] = base + stride * offset[idx];
@@ -32,22 +35,35 @@ template <typename scalar_t>
 void moe_cuda_forward_impl(
         const scalar_t* input,
         const int* d_gate,
-        const scalar_t* weight,
+        const scalar_t* weight1,
+        const scalar_t* weight2,
         scalar_t* output,
         const size_t batch_size,
         const size_t in_feat,
+        const size_t hidden_feat,
         const size_t out_feat,
-        const size_t num_expert,
-        cublasOperation_t transb) {
+        const size_t num_expert) {
     auto h = getCudaStreamManager(num_expert);
 
-    scalar_t *input_buf, *output_buf;
+#ifdef MOE_BREAKDOWN
+    timestamp(t_init);
+#endif
+
+    scalar_t *input_buf, *hidden_buf, *output_buf;
     checkCudaErrors(cudaMalloc(&input_buf, sizeof(scalar_t) * batch_size *
                 in_feat));
     checkCudaErrors(cudaMalloc(&output_buf, sizeof(scalar_t) * batch_size *
                 out_feat));
+    checkCudaErrors(cudaMalloc(&hidden_buf, sizeof(scalar_t) * batch_size *
+                hidden_feat));
+
+#ifdef MOE_BREAKDOWN
+    timestamp(t_malloc);
+    fprintf(stderr, "Malloc time %.3lf us\n", getDuration(t_init, t_malloc) *
+            1e6);
+#endif
 
     int *gate = new int[batch_size];
     int *expert_count = new int[num_expert], *expert_ptr = new int[num_expert];
@@ -55,6 +71,13 @@ void moe_cuda_forward_impl(
     checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size,
                 cudaMemcpyDeviceToHost));
 
+#ifdef MOE_BREAKDOWN
+    timestamp(t_cpy);
+    fprintf(stderr, "Copy time %.3lf us\n", getDuration(t_malloc, t_cpy) *
+            1e6);
+#endif
+
     for (int i = 0; i < batch_size; ++i) {
         ++expert_count[gate[i]];
     }
@@ -62,6 +85,13 @@ void moe_cuda_forward_impl(
     for (int i = 1; i < num_expert; ++i) {
         expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
     }
 
+#ifdef MOE_BREAKDOWN
+    timestamp(t_expert);
+    fprintf(stderr, "Expert asn time %.3lf us\n", getDuration(t_cpy, t_expert) *
+            1e6);
+#endif
+
     for (int i = 0; i < batch_size; ++i) {
         int target_idx = expert_ptr[gate[i]]++;
 #ifdef MOE_DEBUG_SCATTER
@@ -73,6 +103,13 @@ void moe_cuda_forward_impl(
                     h->getStream(gate[i])));
     }
 
+#ifdef MOE_BREAKDOWN
+    h->sync();
+    timestamp(t_scatter);
+    fprintf(stderr, "Scatter time %.3lf us\n", getDuration(t_expert, t_scatter) *
+            1e6);
+#endif
+
     scalar_t alpha = 1, beta = 0;
 
     for (int i = 0, ptr = 0; i < num_expert; ++i) {
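For readers following the host-side bookkeeping: the loops above implement a small counting sort over the gate indices, so that all rows routed to the same expert become contiguous before the per-expert GEMMs. A hedged Python sketch of the same logic (illustration only, not the kernel):

```python
def group_by_expert(gate, num_expert):
    expert_count = [0] * num_expert
    for g in gate:                        # histogram of samples per expert
        expert_count[g] += 1
    expert_ptr = [0] * num_expert         # exclusive prefix sum = segment starts
    for i in range(1, num_expert):
        expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1]
    pos = list(expert_ptr)
    target = [0] * len(gate)              # scatter: row i lands in slot target[i]
    for i, g in enumerate(gate):
        target[i] = pos[g]
        pos[g] += 1
    return expert_count, expert_ptr, target

# e.g. gate = [1, 0, 1, 0] -> counts [2, 2], ptrs [0, 2], targets [2, 0, 3, 1]
```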
@@ -86,19 +123,37 @@ void moe_cuda_forward_impl(
 #endif
         // Use T(B) x T(A) = T(C) to produce row-major C
         checkCudaErrors(cublasXgemm(h->getHandle(i),
-                (transb == CUBLAS_OP_T) ? CUBLAS_OP_N : CUBLAS_OP_T,
+                CUBLAS_OP_T,
                 CUBLAS_OP_N,
-                out_feat, expert_count[i], in_feat,
+                hidden_feat, expert_count[i], in_feat,
                 &alpha,
-                weight + i * in_feat * out_feat,
-                (transb == CUBLAS_OP_T) ? out_feat : in_feat,
+                weight1 + i * in_feat * hidden_feat, in_feat,
                 input_buf + ptr * in_feat, in_feat,
                 &beta,
-                output_buf + out_feat * ptr,
-                out_feat
+                hidden_buf + hidden_feat * ptr, hidden_feat
                 ));
+        checkCudaErrors(cublasXgemm(h->getHandle(i),
+                CUBLAS_OP_T,
+                CUBLAS_OP_N,
+                out_feat, expert_count[i], hidden_feat,
+                &alpha,
+                weight2 + i * hidden_feat * out_feat, hidden_feat,
+                hidden_buf + hidden_feat * ptr, hidden_feat,
+                &beta,
+                output_buf + out_feat * ptr, out_feat
+                ));
         ptr += expert_count[i];
     }
 
+#ifdef MOE_BREAKDOWN
+    h->sync();
+    timestamp(t_mm);
+    fprintf(stderr, "GeMM time %.3lf us\n", getDuration(t_scatter, t_mm) *
+            1e6);
+#endif
+
     for (int i = batch_size - 1; i >= 0; --i) {
         int target_idx = --expert_ptr[gate[i]];
 #ifdef MOE_DEBUG_SCATTER
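A note on the CUBLAS_OP_T arguments, which are the heart of the transpose fix: cuBLAS is column-major, so a contiguous row-major weight of shape (rows, cols) is, from cuBLAS's point of view, already the transposed matrix with leading dimension cols. Requesting op = T recovers the row-major weight, and the GEMM result, read back row-major, is X @ W^T, exactly what nn.Linear computes. A hedged Python emulation of the first GEMM (illustration only):

```python
import torch

def as_cublas(buf, rows, cols):
    # Reinterpret a contiguous row-major buffer as a column-major rows x cols matrix.
    return buf.reshape(cols, rows).t()

count, in_feat, hidden_feat = 5, 3, 4
X = torch.randn(count, in_feat)           # one expert's input_buf segment, row-major
W1 = torch.randn(hidden_feat, in_feat)    # one expert's weight1, row-major

A = as_cublas(W1.reshape(-1), in_feat, hidden_feat)   # lda = in_feat; cuBLAS sees W1^T
B = as_cublas(X.reshape(-1), in_feat, count)          # ldb = in_feat; cuBLAS sees X^T
C = A.t() @ B                     # opA = CUBLAS_OP_T: (W1^T)^T @ X^T = W1 @ X^T
hidden = C.t()                    # column-major C, read row-major, is count x hidden_feat
assert torch.allclose(hidden, X @ W1.t())
```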
@@ -113,6 +168,14 @@ void moe_cuda_forward_impl(
     h->sync();
 
+#ifdef MOE_BREAKDOWN
+    timestamp(t_gather);
+    fprintf(stderr, "Gather time %.3lf us\n", getDuration(t_mm, t_gather) *
+            1e6);
+    fprintf(stderr, "Overall time %.3lf us\n", getDuration(t_init, t_gather) *
+            1e6);
+#endif
+
     cudaFree(input_buf);
+    cudaFree(hidden_buf);  // free the intermediate buffer too; omitting this leaks it
     cudaFree(output_buf);
 }
@@ -159,14 +222,17 @@ void moe_cuda_grad_weight(
 std::vector<torch::Tensor> moe_cuda_forward(
         torch::Tensor input,
         torch::Tensor gate,
-        torch::Tensor weight) {
+        torch::Tensor weight1,
+        torch::Tensor weight2
+        ) {
     const auto batch_size = input.size(0);
-    const auto num_expert = weight.size(0);
-    const auto out_feat = weight.size(1);
-    const auto in_feat = weight.size(2);
+    const auto num_expert = weight1.size(0);
+    const auto out_feat = weight2.size(1);
+    const auto hidden_feat = weight1.size(1);
+    const auto in_feat = weight1.size(2);
 
 #ifdef MOE_DEBUG
-    printf("[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
+    printf("[forward] b=%ld, expert=%ld, in_feat (d_model)=%ld, hidden_feat=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, hidden_feat, out_feat);
 #endif
 
     auto output = input.new_zeros({batch_size, out_feat});
@@ -174,13 +240,14 @@ std::vector<torch::Tensor> moe_cuda_forward(
         moe_cuda_forward_impl<scalar_t>(
             input.data_ptr<scalar_t>(),
             gate.data_ptr<int>(),
-            weight.data_ptr<scalar_t>(),
+            weight1.data_ptr<scalar_t>(),
+            weight2.data_ptr<scalar_t>(),
             output.data_ptr<scalar_t>(),
             batch_size,
             in_feat,
+            hidden_feat,
             out_feat,
-            num_expert,
-            CUBLAS_OP_T
+            num_expert
         );
     }));
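Putting the pieces together, the extension's forward entry point now expects the shapes below. An illustrative call under the commit's conventions (gate must be int32 on the same device; the extension returns a list, hence the `[0]`):

```python
import torch
import moe_cuda  # the compiled extension from this repo

B, E, I, H, O = 4, 2, 6, 12, 7
inp = torch.rand(B, I).cuda()
gate = torch.randint(0, E, (B,)).int().cuda()   # int32 expert index per sample
w1 = torch.randn(E, H, I).cuda()                # [num_expert, hidden_feat, in_feat]
w2 = torch.randn(E, O, H).cuda()                # [num_expert, out_feat, hidden_feat]
out = moe_cuda.forward(inp, gate, w1, w2)[0]    # [batch_size, out_feat]
```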
@@ -205,6 +272,7 @@ std::vector<torch::Tensor> moe_cuda_backward(
     auto grad_weight = grad_output.new_zeros({num_expert, out_feat, in_feat}); // num_expert x out_feat x in_feat
 
     // grad_input is easy to compute, exactly the same as forward
+    /* TODO: Backward is currently broken
     AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
         moe_cuda_forward_impl<scalar_t>(
             grad_output.data_ptr<scalar_t>(),
@@ -218,6 +286,7 @@ std::vector<torch::Tensor> moe_cuda_backward(
             CUBLAS_OP_N
         );
     }));
+    */
 
     AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
         moe_cuda_grad_weight<scalar_t>(
...
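The commit comments out the grad_input computation because the old single-matrix path no longer matches the two-level forward. For reference, a hedged sketch of the math a fixed backward would need, per expert e with row-major weights: grad_hid = grad_out @ W2[e] and grad_in = grad_hid @ W1[e], with the weight gradients as the matching outer products. A plain-torch reference (illustration only, not the repo's kernel):

```python
import torch

def moe_two_level_backward(grad_out, inp, gate, w1, w2):
    """Per-sample gradients for: hid = inp[i] @ w1[e].t(); out = hid @ w2[e].t()."""
    g = gate.long()
    grad_in = torch.zeros_like(inp)
    grad_w1, grad_w2 = torch.zeros_like(w1), torch.zeros_like(w2)
    for i in range(inp.size(0)):
        e = g[i]
        hid = inp[i] @ w1[e].t()
        grad_hid = grad_out[i] @ w2[e]               # d out / d hid
        grad_in[i] = grad_hid @ w1[e]                # d out / d inp
        grad_w2[e] += torch.outer(grad_out[i], hid)  # accumulate per expert
        grad_w1[e] += torch.outer(grad_hid, inp[i])
    return grad_in, grad_w1, grad_w2
```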