Commit d4dd2a6c authored by Jiezhong Qiu

debugging

parent 93291a7e
@@ -6,7 +6,7 @@ import torch
 import moe_cuda
 
 torch.manual_seed(42)
+torch.cuda.manual_seed(42)
 
 class MOEFunction(Function):
     @staticmethod
@@ -27,29 +27,70 @@ class MOEFunction(Function):
 class MOELayer(nn.Module):
     def __init__(self, num_expert=32, in_feat=1024, out_feat=4096):
         super(MOELayer, self).__init__()
+        self.num_expert = num_expert
+        self.in_feat = in_feat
+        self.out_feat = out_feat
         self.weight = nn.Parameter(
             torch.Tensor(num_expert, out_feat, in_feat))
         self.reset_parameters()
 
     def reset_parameters(self):
-        pass
+        # initialize each expert's weight the way nn.Linear initializes its own
+        for i in range(self.num_expert):
+            linear = nn.Linear(in_features=self.in_feat, out_features=self.out_feat)
+            self.weight.data[i] = linear.weight.data
 
     def forward(self, input, gate):
         return MOEFunction.apply(input, gate, self.weight)
 
-batch_size = 64
-num_expert = 32
-in_feat = 512
-out_feat = 512
+class MOELayer_einsum(nn.Module):
+    def __init__(self, num_expert=32, in_feat=1024, out_feat=4096):
+        super(MOELayer_einsum, self).__init__()
+        self.num_expert = num_expert
+        self.in_feat = in_feat
+        self.out_feat = out_feat
+        self.weight = nn.Parameter(
+            torch.Tensor(num_expert, out_feat, in_feat))
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        for i in range(self.num_expert):
+            linear = nn.Linear(in_features=self.in_feat, out_features=self.out_feat)
+            self.weight.data[i] = linear.weight.data
+
+    def forward(self, input, gate):
+        gate_long = gate.long()
+        #W = self.weight[gate_long]  # [batch_size x out_feat x in_feat]
+        #x = torch.einsum('id,ihd->ih', (input, W))  # [batch_size x out_feat]
+        #return x
+        batch_size = input.size(0)
+        x = input.new_zeros((batch_size, self.out_feat))
+        for i in range(batch_size):
+            x[i] = self.weight[gate_long[i]] @ input[i]
+        return x
 
 moe = MOELayer(num_expert, in_feat, out_feat).cuda()
+moe_einsum = MOELayer_einsum(num_expert, in_feat, out_feat).cuda()
+moe_einsum.weight.data = moe.weight.data.clone()
+
+batch_size = 2
+num_expert = 2
+in_feat = 2
+out_feat = 4
 
 input = torch.rand(batch_size, in_feat).cuda()
 gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda()
+print(input)
+print(gate)
 output = moe(input, gate)
+print(input)
+print(gate)
+output_einsum = moe_einsum(input, gate)
+print(output)
+print(output_einsum)
 
-y = output.mean()
-y.backward()
+#y = output.mean()
+#y.backward()
\ No newline at end of file
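
The einsum path that is commented out above can be sanity-checked against the explicit loop in pure PyTorch, without the CUDA extension. A minimal CPU-only sketch (all names are local to this example, not part of the repo):

import torch

batch_size, num_expert, in_feat, out_feat = 2, 2, 2, 4
weight = torch.randn(num_expert, out_feat, in_feat)  # one [out_feat x in_feat] matrix per expert
x = torch.randn(batch_size, in_feat)
gate = torch.randint(0, num_expert, (batch_size,))

# einsum path: gather each sample's expert weight, then a batched matrix-vector product
W = weight[gate]                             # [batch_size x out_feat x in_feat]
y_einsum = torch.einsum('id,ihd->ih', x, W)  # same subscripts as the commented-out code

# loop path, mirroring MOELayer_einsum.forward
y_loop = torch.stack([weight[gate[i]] @ x[i] for i in range(batch_size)])

assert torch.allclose(y_einsum, y_loop)

Once the two pure-PyTorch paths agree, MOELayer_einsum serves as the reference oracle for the CUDA output printed in the test above.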
@@ -151,7 +151,7 @@ void moe_cuda_forward_impl(
     checkCudaErrors(cublasSetStream(h->handle, *(h->streams)));
 
     // setup Aarray, Barray and Carray
-    std::vector<const scalar_t*> aptrs;
+    std::vector<const scalar_t*> aptrs, bptrs;
     std::vector<scalar_t*> cptrs;
     const scalar_t **Aarray;
@@ -163,6 +163,7 @@ void moe_cuda_forward_impl(
     for (size_t i=0; i<batch_size; ++i) {
         aptrs.push_back(input + in_feat * i);
+        bptrs.push_back(weight + out_feat * in_feat * i);
         cptrs.push_back(output + out_feat * i);
     }
     checkCudaErrors(cudaMemcpy(Aarray, aptrs.data(), batch_size * sizeof(const scalar_t*), cudaMemcpyHostToDevice));
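
The per-sample B pointers are what make the batched GEMM gate-aware: Aarray and Carray stride contiguously through input and output, while Barray (filled on-device by generate_ptr_offset_kernel, called in the next hunk) must point at the weight matrix selected by each sample's gate. The kernel body is not part of this diff; a minimal sketch consistent with its call site (the gate dtype is assumed int32, matching the .int() cast in the test above):

template <typename scalar_t>
__global__ void generate_ptr_offset_kernel(size_t n, const scalar_t* base,
                                           size_t stride, const int* gate,
                                           const scalar_t** ptrs) {
    size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
    if (idx < n) {
        // point row i of the batched GEMM at expert gate[i]'s weight matrix
        ptrs[idx] = base + stride * gate[idx];
    }
}

Note that the host-side bptrs filled in the previous hunk offset by the batch index i rather than by gate[i] (gate lives in device memory), so the two printed pointer columns are only expected to coincide when gate[i] == i; they are reference offsets for eyeballing, not an exact oracle.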
@@ -173,14 +174,23 @@ void moe_cuda_forward_impl(
     dim3 blockdim(256);
     generate_ptr_offset_kernel<<<griddim, blockdim, 0, *(h->streams)>>>(batch_size, weight, out_feat * in_feat, gate, Barray);
+    // debug: copy the gated weight pointers back to the host and print them
+    // next to the host-computed offsets for comparison
+    const scalar_t **B = (const scalar_t **)malloc(batch_size * sizeof(const scalar_t*));
+    checkCudaErrors(cudaMemcpy(B, Barray, batch_size * sizeof(const scalar_t*), cudaMemcpyDeviceToHost));
+    std::cout << weight << std::endl;
+    for (size_t i=0; i<batch_size; ++i) {
+        std::cout << B[i] << " " << bptrs[i] << std::endl;
+    }
     scalar_t alpha = 1, beta = 0;
     checkCudaErrors(cublasXgemmBatched(h->handle,
             CUBLAS_OP_N,
             transb,
             1, out_feat, in_feat,
             &alpha,
             Aarray, 1,
-            Barray, out_feat,
+            Barray, (transb == CUBLAS_OP_T) ? out_feat : in_feat,
             &beta,
             Carray, 1,
             batch_size));
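
The leading-dimension change above is the substantive fix: each expert's weight is stored row-major as [out_feat x in_feat], and cuBLAS is column-major, so the same buffer read with CUBLAS_OP_N must use ldb = in_feat (it is then exactly W^T, and the 1 x out_feat GEMM computes y = W x per sample); ldb = out_feat is only correct on the CUBLAS_OP_T path. A CPU-only sketch of that indexing argument (no cuBLAS; all names are local to the example):

#include <cassert>
#include <cstddef>

// Minimal column-major "gemm" matching the m = 1 case used above:
// C(0, n) = sum_k A(0, k) * B(k, n), with B read column-major via ldb.
static void gemm_1xN(size_t N, size_t K, const float* A,
                     const float* B, size_t ldb, float* C) {
    for (size_t n = 0; n < N; ++n) {
        float acc = 0;
        for (size_t k = 0; k < K; ++k)
            acc += A[k] * B[k + n * ldb];  // column-major B(k, n)
        C[n] = acc;
    }
}

int main() {
    const size_t out_feat = 4, in_feat = 2;
    // row-major weight W[o][i], as in the [num_expert, out_feat, in_feat] tensor
    float W[out_feat * in_feat] = {1, 2, 3, 4, 5, 6, 7, 8};
    float x[in_feat] = {10, 20};
    float y[out_feat];

    // CUBLAS_OP_N path: the row-major W buffer, read column-major with
    // ldb = in_feat, is W^T, so C = x^T * W^T = (W x)^T
    gemm_1xN(out_feat, in_feat, x, W, in_feat, y);

    for (size_t o = 0; o < out_feat; ++o) {
        float ref = 0;
        for (size_t i = 0; i < in_feat; ++i) ref += W[o * in_feat + i] * x[i];
        assert(y[o] == ref);
    }
    return 0;
}

With the old unconditional ldb = out_feat, the CUBLAS_OP_N path would stride through the weight buffer incorrectly whenever out_feat != in_feat, which is consistent with the tiny asymmetric test shapes (in_feat = 2, out_feat = 4) chosen in this commit.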
@@ -234,7 +244,7 @@ std::vector<torch::Tensor> moe_cuda_forward(
     const auto out_feat = weight.size(1);
     const auto in_feat = weight.size(2);
-    // printf("b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld, topk=%ld\n", batch_size, num_expert, in_feat, out_feat, top_k);
+    printf("b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n", batch_size, num_expert, in_feat, out_feat);
     auto output = input.new_zeros({batch_size, out_feat});
     AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_forward_cuda", ([&] {
...