Commit 9cae3294 authored by Rick Ho's avatar Rick Ho
Browse files

Naive batched memcpy through per-expert staging buffers

parent 8cff6ad7
#ifndef CUBLAS_WRAPPER_H #ifndef CUBLAS_WRAPPER_H
#define CUBLAS_WRAPPER_H #define CUBLAS_WRAPPER_H
#include <cublas_v2.h> #include <cublas_v2.h>
inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle, inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
cublasOperation_t transa, cublasOperation_t transa,
......
...@@ -6,24 +6,27 @@ ...@@ -6,24 +6,27 @@
#include <helper_cuda.h> #include <helper_cuda.h>
class CudaStreamManager { struct CudaStreamManager {
public: const size_t num_expert;
cublasHandle_t* handles;
cudaStream_t* streams;
CudaStreamManager(const size_t num_expert_) : num_expert(num_expert_) { CudaStreamManager(const size_t num_expert_) : num_expert(num_expert_) {
streams = new cudaStream_t[num_expert]; streams = new cudaStream_t[num_expert];
checkCudaErrors(cublasCreate(&handle)); handles = new cublasHandle_t[num_expert];
for (size_t i=0; i<num_expert; ++i) { for (size_t i=0; i<num_expert; ++i) {
checkCudaErrors(cudaStreamCreate(streams+i)); checkCudaErrors(cublasCreate(handles + i));
} checkCudaErrors(cudaStreamCreate(streams + i));
checkCudaErrors(cublasSetStream(handles[i], streams[i]));
}
} }
~CudaStreamManager() { ~CudaStreamManager() {
for (size_t i=0; i<num_expert; ++i) { for (size_t i=0; i<num_expert; ++i) {
checkCudaErrors(cudaStreamDestroy(*(streams+i))); checkCudaErrors(cudaStreamDestroy(streams[i]));
} checkCudaErrors(cublasDestroy(handles[i]));
checkCudaErrors(cublasDestroy(handle)); }
} }
const size_t num_expert;
cublasHandle_t handle;
cudaStream_t* streams;
}; };
CudaStreamManager* getCudaStreamManager(const size_t num_expert); CudaStreamManager* getCudaStreamManager(const size_t num_expert);
......
...@@ -21,12 +21,12 @@ class MOEFunction(Function): ...@@ -21,12 +21,12 @@ class MOEFunction(Function):
@staticmethod
def backward(ctx, grad_out):
    """Backward pass of the MOE op.

    Delegates to the CUDA extension, then converts the weight gradient
    from the column-major layout the kernel produces (per expert,
    in_feat x out_feat flattened) back to the row-major
    (num_expert, out_feat, in_feat) layout of the parameter.
    Returns (grad_input, None for the gate, grad_weight).
    """
    grad_inp, grad_weight = moe_cuda.backward(
        grad_out.contiguous(), *ctx.saved_tensors)
    out_feat, in_feat = grad_weight.size()[1:]
    grad_weight_row_major = (
        grad_weight.view(-1, in_feat, out_feat)
        .transpose(-1, -2)
        .contiguous()
        .view(-1, out_feat, in_feat)
    )
    return grad_inp, None, grad_weight_row_major
...@@ -74,6 +74,17 @@ class MOELayer_raw(nn.Module): ...@@ -74,6 +74,17 @@ class MOELayer_raw(nn.Module):
return x return x
def test_module(moe, linear, inp, gate):
linear.zero_grad()
moe.zero_grad()
x = linear(inp)
output = moe(x, gate)
print(output)
y = output.mean()
y.backward()
return output, moe.weight.grad, linear.weight.grad, linear.bias.grad
def test(): def test():
batch_size = 4 batch_size = 4
num_expert = 4 num_expert = 4
...@@ -89,28 +100,13 @@ def test(): ...@@ -89,28 +100,13 @@ def test():
inp = torch.rand(batch_size, in_feat).cuda() inp = torch.rand(batch_size, in_feat).cuda()
gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda() gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda()
linear.zero_grad() moe_out = test_module(moe, linear, inp.clone(), gate.clone())
moe.zero_grad() raw_out = test_module(moe_raw, linear, inp.clone(), gate.clone())
x = linear(inp)
output = moe(x, gate)
print("moe output", output)
y = output.mean()
y.backward()
print("moe.weight.grad", moe.weight.grad)
print("linear.weight.grad", linear.weight.grad)
print("linear.bias.grad", linear.bias.grad)
linear.zero_grad() names = ['Out', 'Moe wei', 'Linear wei', 'Linear bias']
moe.zero_grad() for name, mo, ro in zip(names, moe_out, raw_out):
x = linear(inp.clone()) err = (mo - ro).abs().sum()
output_raw= moe_raw(x, gate.clone()) print('{} abs err {}'.format(name, err))
print("moe_raw output", output_raw)
y_raw = output_raw.mean()
y_raw.backward()
print("moe_raw.weight.grad", moe_raw.weight.grad)
print("linear_raw.weight.grad", linear.weight.grad)
print("linear_raw.bias.grad", linear.bias.grad)
if __name__ == '__main__': if __name__ == '__main__':
test() test()
...@@ -31,7 +31,7 @@ void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride, c ...@@ -31,7 +31,7 @@ void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride, c
template <typename scalar_t> template <typename scalar_t>
void moe_cuda_forward_impl( void moe_cuda_forward_impl(
const scalar_t* input, const scalar_t* input,
const int* gate, const int* d_gate,
const scalar_t* weight, const scalar_t* weight,
scalar_t* output, scalar_t* output,
const size_t batch_size, const size_t batch_size,
...@@ -40,50 +40,82 @@ void moe_cuda_forward_impl( ...@@ -40,50 +40,82 @@ void moe_cuda_forward_impl(
const size_t num_expert, const size_t num_expert,
cublasOperation_t transb) { cublasOperation_t transb) {
auto* h = getCudaStreamManager(num_expert); auto h = getCudaStreamManager(num_expert);
scalar_t *input_buf, *output_buf;
checkCudaErrors(cublasSetStream(h->handle, *(h->streams))); checkCudaErrors(cudaMalloc(&input_buf, sizeof(scalar_t) * batch_size *
in_feat));
checkCudaErrors(cudaMalloc(&output_buf, sizeof(scalar_t) * batch_size *
out_feat));
// setup Aarray, Barray and Carray int *gate = new int[batch_size];
std::vector<const scalar_t*> aptrs; int *expert_count = new int[num_expert], *expert_ptr = new int[num_expert];
std::vector<scalar_t*> cptrs; memset(expert_count, 0, sizeof(int) * num_expert);
const scalar_t **Aarray;
const scalar_t **Barray;
scalar_t **Carray;
checkCudaErrors(cudaMalloc(&Aarray, batch_size * sizeof(const scalar_t*)));
checkCudaErrors(cudaMalloc(&Barray, batch_size * sizeof(const scalar_t*)));
checkCudaErrors(cudaMalloc(&Carray, batch_size * sizeof(scalar_t*)));
for (size_t i=0; i<batch_size; ++i) { checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size,
aptrs.push_back(input + in_feat * i); cudaMemcpyDeviceToHost));
cptrs.push_back(output + out_feat * i); for (int i = 0; i < batch_size; ++i) {
++expert_count[gate[i]];
}
expert_ptr[0] = 0;
for (int i = 1; i < num_expert; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
}
for (int i = 0; i < batch_size; ++i) {
int target_idx = expert_ptr[gate[i]]++;
#ifdef MOE_DEBUG_SCATTER
fprintf(stderr, "aln idx %d gate %d tgt %d\n", i, gate[i], target_idx);
#endif
checkCudaErrors(cudaMemcpyAsync(input_buf + target_idx * in_feat,
input + i * in_feat, sizeof(scalar_t) * in_feat,
cudaMemcpyDeviceToDevice,
h->streams[gate[i]]));
} }
checkCudaErrors(cudaMemcpy(Aarray, aptrs.data(), batch_size * sizeof(const
scalar_t*), cudaMemcpyHostToDevice));
// checkCudaErrors(cudaMemcpy(ptrs + batch_size * top_k, bptrs.data(),
// batch_size * sizeof(scalar_t*) * top_k, cudaMemcpyHostToDevice));
checkCudaErrors(cudaMemcpy(Carray, cptrs.data(), batch_size *
sizeof(scalar_t*), cudaMemcpyHostToDevice));
dim3 griddim(CEIL(batch_size, 256)); dim3 blockdim(256);
generate_ptr_offset_kernel<<<griddim, blockdim, 0,
*(h->streams)>>>(batch_size, weight, out_feat * in_feat, gate, Barray);
scalar_t alpha = 1, beta = 0; scalar_t alpha = 1, beta = 0;
checkCudaErrors(cublasXgemmBatched(h->handle, for (int i = 0, ptr = 0; i < num_expert; ++i) {
if (expert_count[i] == 0) {
continue;
}
#ifdef MOE_DEBUG_SCATTER
fprintf(stderr, "gemm %d sz %d\n", i, expert_count[i]);
fprintf(stderr, "GeMM %d x %d x %d\n", out_feat, expert_count[i],
in_feat);
#endif
// Use T(B) x T(A) = T(C) to produce row-major C
checkCudaErrors(cublasXgemm(h->handles[i],
(transb == CUBLAS_OP_T) ? CUBLAS_OP_N : CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
transb, out_feat, expert_count[i], in_feat,
1, out_feat, in_feat,
&alpha, &alpha,
Aarray, 1, weight + i * in_feat * out_feat,
Barray, (transb == CUBLAS_OP_T) ? out_feat : in_feat, (transb == CUBLAS_OP_T) ? out_feat : in_feat,
input_buf + ptr * in_feat, in_feat,
&beta, &beta,
Carray, 1, output_buf + out_feat * ptr,
batch_size)); out_feat
));
ptr += expert_count[i];
}
for (int i = batch_size - 1; i >= 0; --i) {
int target_idx = --expert_ptr[gate[i]];
#ifdef MOE_DEBUG_SCATTER
fprintf(stderr, "cb idx %d gate %d tgt %d\n", i, gate[i], target_idx);
#endif
checkCudaErrors(cudaMemcpyAsync(output + i * out_feat,
output_buf + target_idx * out_feat,
sizeof(scalar_t) * out_feat,
cudaMemcpyDeviceToDevice,
h->streams[gate[i]]));
}
checkCudaErrors(cudaStreamSynchronize(*(h->streams))); for (int i = 0; i < num_expert; ++i) {
cudaStreamSynchronize(h->streams[i]);
}
cudaFree(input_buf);
cudaFree(output_buf);
} }
template <typename scalar_t> template <typename scalar_t>
...@@ -103,8 +135,8 @@ void moe_cuda_grad_weight( ...@@ -103,8 +135,8 @@ void moe_cuda_grad_weight(
scalar_t alpha = 1, beta = 1; scalar_t alpha = 1, beta = 1;
checkCudaErrors(cudaMemcpy(gate_host, gate, batch_size * sizeof(int), cudaMemcpyDeviceToHost)); checkCudaErrors(cudaMemcpy(gate_host, gate, batch_size * sizeof(int), cudaMemcpyDeviceToHost));
for (size_t i=0; i<batch_size; ++i) { for (size_t i=0; i<batch_size; ++i) {
checkCudaErrors(cublasSetStream(h->handle, *(h->streams + gate_host[i]))); checkCudaErrors(cublasSetStream(h->handles[0], *(h->streams + gate_host[i])));
checkCudaErrors(cublasXgemm(h->handle, checkCudaErrors(cublasXgemm(h->handles[0],
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
out_feat, out_feat,
...@@ -166,7 +198,9 @@ std::vector<torch::Tensor> moe_cuda_backward( ...@@ -166,7 +198,9 @@ std::vector<torch::Tensor> moe_cuda_backward(
const auto num_expert = weight.size(0); const auto num_expert = weight.size(0);
const auto out_feat = weight.size(1); const auto out_feat = weight.size(1);
const auto in_feat = weight.size(2); const auto in_feat = weight.size(2);
#ifdef MOE_DEBUG
printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat); printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
#endif
auto grad_input = grad_output.new_zeros({batch_size, in_feat}); // batch_size x in_feat auto grad_input = grad_output.new_zeros({batch_size, in_feat}); // batch_size x in_feat
auto grad_weight = grad_output.new_zeros({num_expert, out_feat, in_feat}); // num_expert x out_feat x in_feat auto grad_weight = grad_output.new_zeros({num_expert, out_feat, in_feat}); // num_expert x out_feat x in_feat
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment