Commit 9cae3294 authored by Rick Ho
Browse files

naive memcpy batch with buffer

parent 8cff6ad7
#ifndef CUBLAS_WRAPPER_H
#define CUBLAS_WRAPPER_H
#include <cublas_v2.h>
#include <cublas_v2.h>
inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
cublasOperation_t transa,
......
......@@ -6,24 +6,27 @@
#include <helper_cuda.h>
class CudaStreamManager {
public:
struct CudaStreamManager {
const size_t num_expert;
cublasHandle_t* handles;
cudaStream_t* streams;
CudaStreamManager(const size_t num_expert_) : num_expert(num_expert_) {
streams = new cudaStream_t[num_expert];
checkCudaErrors(cublasCreate(&handle));
handles = new cublasHandle_t[num_expert];
for (size_t i=0; i<num_expert; ++i) {
checkCudaErrors(cudaStreamCreate(streams+i));
}
checkCudaErrors(cublasCreate(handles + i));
checkCudaErrors(cudaStreamCreate(streams + i));
checkCudaErrors(cublasSetStream(handles[i], streams[i]));
}
}
~CudaStreamManager() {
for (size_t i=0; i<num_expert; ++i) {
checkCudaErrors(cudaStreamDestroy(*(streams+i)));
}
checkCudaErrors(cublasDestroy(handle));
checkCudaErrors(cudaStreamDestroy(streams[i]));
checkCudaErrors(cublasDestroy(handles[i]));
}
}
const size_t num_expert;
cublasHandle_t handle;
cudaStream_t* streams;
};
CudaStreamManager* getCudaStreamManager(const size_t num_expert);
......
......@@ -21,12 +21,12 @@ class MOEFunction(Function):
@staticmethod
def backward(ctx, grad_out):
    """Compute gradients for the custom MoE op.

    Args:
        ctx: autograd context holding the tensors saved by forward.
        grad_out: gradient of the loss w.r.t. the layer output.

    Returns:
        Tuple ``(grad_inp, None, grad_weight_row_major)`` matching
        forward's ``(input, gate, weight)`` arguments; the integer gate
        tensor receives no gradient.
    """
    # Debug prints retired in this commit; kept for quick re-enabling.
    # print("grad_out", grad_out)
    # print("input", ctx.saved_tensors[0])
    grad_inp, grad_weight = moe_cuda.backward(
        grad_out.contiguous(), *ctx.saved_tensors)
    # The CUDA kernel produces each expert's weight gradient in
    # column-major (in_feat x out_feat) order (see the retired debug
    # print below); transpose back to the row-major layout callers use.
    out_feat, in_feat = grad_weight.size()[1:]
    # print("grad_weight_column_major", grad_weight.flatten())
    grad_weight_row_major = grad_weight.view(
        -1, in_feat, out_feat).transpose(-1, -2).contiguous().view(
        -1, out_feat, in_feat)
    return grad_inp, None, grad_weight_row_major
......@@ -74,6 +74,17 @@ class MOELayer_raw(nn.Module):
return x
def test_module(moe, linear, inp, gate):
linear.zero_grad()
moe.zero_grad()
x = linear(inp)
output = moe(x, gate)
print(output)
y = output.mean()
y.backward()
return output, moe.weight.grad, linear.weight.grad, linear.bias.grad
def test():
batch_size = 4
num_expert = 4
......@@ -89,28 +100,13 @@ def test():
inp = torch.rand(batch_size, in_feat).cuda()
gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda()

# Run the custom CUDA MoE and the raw PyTorch reference on identical
# cloned inputs, then report the absolute error between each pair of
# results (output plus the three gradients test_module returns).
moe_out = test_module(moe, linear, inp.clone(), gate.clone())
raw_out = test_module(moe_raw, linear, inp.clone(), gate.clone())

names = ['Out', 'Moe wei', 'Linear wei', 'Linear bias']
for name, mo, ro in zip(names, moe_out, raw_out):
    err = (mo - ro).abs().sum()
    print('{} abs err {}'.format(name, err))
# Script entry point: runs the MoE vs. reference comparison in test().
if __name__ == '__main__':
test()
......@@ -31,7 +31,7 @@ void generate_ptr_offset_kernel(size_t n, const scalar_t* base, size_t stride, c
// Forward pass of the MoE layer: group the batch rows by their gate
// (expert) assignment into a staging buffer, run one GEMM per non-empty
// expert on that expert's dedicated cuBLAS handle/stream, then scatter
// the results back to the original row order.
//
// input:  device ptr, batch_size x in_feat (row-major)
// d_gate: device ptr, batch_size expert indices in [0, num_expert)
// weight: device ptr, num_expert x out_feat x in_feat
// output: device ptr, batch_size x out_feat (row-major), written here
template <typename scalar_t>
void moe_cuda_forward_impl(
        const scalar_t* input,
        const int* d_gate,
        const scalar_t* weight,
        scalar_t* output,
        const size_t batch_size,
        // NOTE(review): the diff elides part of this parameter list; the
        // in_feat/out_feat parameters are reconstructed from their uses
        // below — confirm against the full source.
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
        cublasOperation_t transb) {
    auto h = getCudaStreamManager(num_expert);

    // Staging buffers holding the batch re-ordered so each expert's rows
    // are contiguous.  (NOTE(review): cudaMalloc per call is expensive;
    // consider caching these buffers across invocations.)
    scalar_t *input_buf, *output_buf;
    checkCudaErrors(cudaMalloc(&input_buf, sizeof(scalar_t) * batch_size *
            in_feat));
    checkCudaErrors(cudaMalloc(&output_buf, sizeof(scalar_t) * batch_size *
            out_feat));

    // Copy the gate assignment to the host and histogram rows per expert.
    int *gate = new int[batch_size];
    int *expert_count = new int[num_expert], *expert_ptr = new int[num_expert];
    memset(expert_count, 0, sizeof(int) * num_expert);
    checkCudaErrors(cudaMemcpy(gate, d_gate, sizeof(int) * batch_size,
            cudaMemcpyDeviceToHost));
    for (int i = 0; i < batch_size; ++i) {
        ++expert_count[gate[i]];
    }
    // Exclusive prefix sum: expert_ptr[e] = first staged row of expert e.
    expert_ptr[0] = 0;
    for (int i = 1; i < num_expert; ++i) {
        expert_ptr[i] = expert_ptr[i - 1] + expert_count[i - 1];
    }

    // Scatter: copy row i into its expert's slot, issued on that expert's
    // stream.  expert_ptr[e] is advanced past e's region as a side effect.
    for (int i = 0; i < batch_size; ++i) {
        int target_idx = expert_ptr[gate[i]]++;
#ifdef MOE_DEBUG_SCATTER
        fprintf(stderr, "aln idx %d gate %d tgt %d\n", i, gate[i], target_idx);
#endif
        checkCudaErrors(cudaMemcpyAsync(input_buf + target_idx * in_feat,
                input + i * in_feat, sizeof(scalar_t) * in_feat,
                cudaMemcpyDeviceToDevice,
                h->streams[gate[i]]));
    }

    // One GEMM per non-empty expert; ptr tracks the running row offset of
    // expert i's contiguous block inside the staging buffers.
    scalar_t alpha = 1, beta = 0;
    for (int i = 0, ptr = 0; i < num_expert; ++i) {
        if (expert_count[i] == 0) {
            continue;
        }
#ifdef MOE_DEBUG_SCATTER
        fprintf(stderr, "gemm %d sz %d\n", i, expert_count[i]);
        fprintf(stderr, "GeMM %d x %d x %d\n", out_feat, expert_count[i],
                in_feat);
#endif
        // Use T(B) x T(A) = T(C) to produce row-major C
        checkCudaErrors(cublasXgemm(h->handles[i],
                (transb == CUBLAS_OP_T) ? CUBLAS_OP_N : CUBLAS_OP_T,
                CUBLAS_OP_N,
                out_feat, expert_count[i], in_feat,
                &alpha,
                weight + i * in_feat * out_feat,
                (transb == CUBLAS_OP_T) ? out_feat : in_feat,
                input_buf + ptr * in_feat, in_feat,
                &beta,
                output_buf + out_feat * ptr,
                out_feat
                ));
        ptr += expert_count[i];
    }

    // Gather: copy each expert-ordered output row back to its original
    // batch position.  Walking backwards with pre-decrement restores
    // expert_ptr to exactly the offsets used during the scatter.
    for (int i = batch_size - 1; i >= 0; --i) {
        int target_idx = --expert_ptr[gate[i]];
#ifdef MOE_DEBUG_SCATTER
        fprintf(stderr, "cb idx %d gate %d tgt %d\n", i, gate[i], target_idx);
#endif
        checkCudaErrors(cudaMemcpyAsync(output + i * out_feat,
                output_buf + target_idx * out_feat,
                sizeof(scalar_t) * out_feat,
                cudaMemcpyDeviceToDevice,
                h->streams[gate[i]]));
    }

    // Wait for every expert stream before releasing the staging buffers.
    // Fix: check the sync/free return codes instead of ignoring them.
    for (int i = 0; i < num_expert; ++i) {
        checkCudaErrors(cudaStreamSynchronize(h->streams[i]));
    }
    // Fix: the original leaked the three host-side arrays.
    delete[] gate;
    delete[] expert_count;
    delete[] expert_ptr;
    checkCudaErrors(cudaFree(input_buf));
    checkCudaErrors(cudaFree(output_buf));
}
template <typename scalar_t>
......@@ -103,8 +135,8 @@ void moe_cuda_grad_weight(
scalar_t alpha = 1, beta = 1;
checkCudaErrors(cudaMemcpy(gate_host, gate, batch_size * sizeof(int), cudaMemcpyDeviceToHost));
for (size_t i=0; i<batch_size; ++i) {
checkCudaErrors(cublasSetStream(h->handle, *(h->streams + gate_host[i])));
checkCudaErrors(cublasXgemm(h->handle,
checkCudaErrors(cublasSetStream(h->handles[0], *(h->streams + gate_host[i])));
checkCudaErrors(cublasXgemm(h->handles[0],
CUBLAS_OP_N,
CUBLAS_OP_T,
out_feat,
......@@ -166,7 +198,9 @@ std::vector<torch::Tensor> moe_cuda_backward(
const auto num_expert = weight.size(0);
const auto out_feat = weight.size(1);
const auto in_feat = weight.size(2);
#ifdef MOE_DEBUG
printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
#endif
auto grad_input = grad_output.new_zeros({batch_size, in_feat}); // batch_size x in_feat
auto grad_weight = grad_output.new_zeros({num_expert, out_feat, in_feat}); // num_expert x out_feat x in_feat
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment