"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "3afcf3cd49661c466c75ea536b0b2a7ff57f9a05"
Commit eb47044a authored by Jiezhong Qiu

topk=1

parent 3a458fa7
@@ -18,7 +18,7 @@ std::vector<torch::Tensor> moe1_cuda_forward(
 std::vector<torch::Tensor> moe1_forward(
     torch::Tensor input,  // [B x D_model]
-    torch::Tensor gate,   // [B x K]
+    torch::Tensor gate,   // [B]
     torch::Tensor weight  // [N x D_ffn x D_model]
 ) {
     CHECK_INPUT(input);
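With top_k fixed to 1, gate collapses from a [B x K] matrix of expert indices to a single index per token. As a hedged illustration only (this helper is not part of the commit), a top-1 gate could be derived from router scores as below, assuming the router emits a [B x N] score matrix:

#include <torch/torch.h>

// Hypothetical helper, not from this repo: pick one expert per token.
// scores: [B x N] router logits; returns gate: [B] int64 expert indices.
torch::Tensor top1_gate(const torch::Tensor& scores) {
    return scores.argmax(/*dim=*/1);
}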
@@ -74,7 +74,6 @@ void moe1_cuda_forward_impl(
     const scalar_t* weight,
     scalar_t* output,
     const size_t batch_size,
-    const size_t top_k,
     const size_t in_feat,
     const size_t out_feat) {
@@ -91,24 +90,21 @@ void moe1_cuda_forward_impl(
     const scalar_t **Aarray;
     const scalar_t **Barray;
     scalar_t **Carray;
-    checkCudaErrors(cudaMalloc(&Aarray, batch_size * sizeof(const scalar_t*) * top_k));
-    checkCudaErrors(cudaMalloc(&Barray, batch_size * sizeof(const scalar_t*) * top_k));
-    checkCudaErrors(cudaMalloc(&Carray, batch_size * sizeof(scalar_t*) * top_k));
+    checkCudaErrors(cudaMalloc(&Aarray, batch_size * sizeof(const scalar_t*)));
+    checkCudaErrors(cudaMalloc(&Barray, batch_size * sizeof(const scalar_t*)));
+    checkCudaErrors(cudaMalloc(&Carray, batch_size * sizeof(scalar_t*)));
     for (size_t i=0; i<batch_size; ++i) {
-        for (size_t k=0; k<top_k; ++k) {
-            aptrs.push_back(input + in_feat * i);
-            // bptrs.push_back(weight + out_feat * in_feat * gate[i * top_k + k]);
-            cptrs.push_back(output + out_feat * (i * top_k + k));
-        }
+        aptrs.push_back(input + in_feat * i);
+        cptrs.push_back(output + out_feat * i);
     }
-    checkCudaErrors(cudaMemcpy(Aarray, aptrs.data(), batch_size * sizeof(const scalar_t*) * top_k, cudaMemcpyHostToDevice));
+    checkCudaErrors(cudaMemcpy(Aarray, aptrs.data(), batch_size * sizeof(const scalar_t*), cudaMemcpyHostToDevice));
     // checkCudaErrors(cudaMemcpy(ptrs + batch_size * top_k, bptrs.data(), batch_size * sizeof(scalar_t*) * top_k, cudaMemcpyHostToDevice));
-    checkCudaErrors(cudaMemcpy(Carray, cptrs.data(), batch_size * sizeof(scalar_t*) * top_k, cudaMemcpyHostToDevice));
+    checkCudaErrors(cudaMemcpy(Carray, cptrs.data(), batch_size * sizeof(scalar_t*), cudaMemcpyHostToDevice));
-    dim3 griddim(CEIL(batch_size * top_k, 256));
+    dim3 griddim(CEIL(batch_size, 256));
     dim3 blockdim(256);
-    generate_ptr_offset_kernel<<<griddim, blockdim, 0, st>>>(batch_size * top_k, weight, out_feat * in_feat, gate, Barray);
+    generate_ptr_offset_kernel<<<griddim, blockdim, 0, st>>>(batch_size, weight, out_feat * in_feat, gate, Barray);
     scalar_t alpha = 1, beta = 0;
     checkCudaErrors(cublasXgemmBatched(handle,
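generate_ptr_offset_kernel is defined outside this hunk; judging only from the call site above, a plausible sketch is the following, where each thread aims Barray[i] at the weight block of the expert chosen for token i (gate's element type is assumed to be int64 here):

#include <cstdint>

// Sketch inferred from the call site, not the repo's actual definition.
template <typename scalar_t>
__global__ void generate_ptr_offset_kernel(size_t n, const scalar_t* base,
        size_t stride, const int64_t* gate, const scalar_t** ptrs) {
    size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        // One pointer per token: expert gate[idx]'s [D_ffn x D_model] block.
        ptrs[idx] = base + stride * gate[idx];
    }
}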
@@ -120,7 +116,7 @@ void moe1_cuda_forward_impl(
         Barray, out_feat,
         &beta,
         Carray, 1,
-        batch_size * top_k));
+        batch_size));
     checkCudaErrors(cudaStreamSynchronize(st));
     checkCudaErrors(cudaStreamDestroy(st));
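cublasXgemmBatched is presumably the repo's type-dispatching shim over cuBLAS's batched GEMMs, so the scalar_t chosen by AT_DISPATCH_FLOATING_TYPES selects the precision. A minimal sketch of the float overload, assuming that design (a matching double overload would forward to cublasDgemmBatched):

#include <cublas_v2.h>

// Assumed wrapper, not shown in this diff: float overload of the shim.
inline cublasStatus_t cublasXgemmBatched(cublasHandle_t handle,
        cublasOperation_t transa, cublasOperation_t transb,
        int m, int n, int k, const float* alpha,
        const float* const Aarray[], int lda,
        const float* const Barray[], int ldb,
        const float* beta, float* const Carray[], int ldc,
        int batchCount) {
    // Float precision resolves to the S-variant of the batched GEMM.
    return cublasSgemmBatched(handle, transa, transb, m, n, k, alpha,
                              Aarray, lda, Barray, ldb, beta,
                              Carray, ldc, batchCount);
}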
@@ -133,13 +129,12 @@ std::vector<torch::Tensor> moe1_cuda_forward(
     torch::Tensor gate,
     torch::Tensor weight) {
     const auto batch_size = input.size(0);
-    const auto top_k = gate.size(1);
     const auto num_expert = weight.size(0);
     const auto out_feat = weight.size(1);
     const auto in_feat = weight.size(2);
     // printf("b=%ld, expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld, topk=%ld\n", batch_size, num_expert, in_feat, out_feat, top_k);
-    auto output = input.new_zeros({batch_size, top_k, out_feat});
+    auto output = input.new_zeros({batch_size, out_feat});
     AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe1_forward_cuda", ([&] {
         moe1_cuda_forward_impl<scalar_t>(
@@ -148,7 +143,6 @@ std::vector<torch::Tensor> moe1_cuda_forward(
             weight.data_ptr<scalar_t>(),
             output.data_ptr<scalar_t>(),
             batch_size,
-            top_k,
             in_feat,
             out_feat
         );
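End to end, the new contract is: input [B x D_model], gate [B], weight [N x D_ffn x D_model], output [B x D_ffn] (the K dimension is gone). A hedged usage sketch with illustrative sizes; the forward declaration stands in for the extension's actual header:

#include <torch/torch.h>
#include <vector>

// Declared by the extension; signature taken from the diff above.
std::vector<torch::Tensor> moe1_forward(torch::Tensor input,
                                        torch::Tensor gate,
                                        torch::Tensor weight);

void demo() {
    const int64_t B = 8, d_model = 16, d_ffn = 32, N = 4;
    auto opts   = torch::TensorOptions().device(torch::kCUDA);
    auto input  = torch::randn({B, d_model}, opts);                 // [B x D_model]
    auto gate   = torch::randint(N, {B}, opts.dtype(torch::kLong)); // [B]: one expert id per token
    auto weight = torch::randn({N, d_ffn, d_model}, opts);          // [N x D_ffn x D_model]
    auto outs   = moe1_forward(input, gate, weight);
    // outs[0] is now [B x D_ffn]; before this commit it was [B x K x D_ffn].
}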