Commit 861b75c1 authored by Rick Ho
Browse files

remove weight input to c expert count function

parent 60b93e39
...@@ -5,8 +5,7 @@ ...@@ -5,8 +5,7 @@
#include <vector> #include <vector>
std::vector<torch::Tensor> moe_cuda_expert_count( std::vector<torch::Tensor> moe_cuda_expert_count(
torch::Tensor weight, // TODO: pass num-experts in another way? torch::Tensor gate, size_t num_expert);
torch::Tensor gate);
std::vector<torch::Tensor> moe_cuda_local_scatter( std::vector<torch::Tensor> moe_cuda_local_scatter(
torch::Tensor input, torch::Tensor input,
...@@ -35,10 +34,10 @@ std::vector<torch::Tensor> moe_cuda_backward( ...@@ -35,10 +34,10 @@ std::vector<torch::Tensor> moe_cuda_backward(
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> moe_expert_count( std::vector<torch::Tensor> moe_expert_count(
torch::Tensor weight, torch::Tensor gate,
torch::Tensor gate) { size_t num_expert) {
CHECK_INPUT(gate); CHECK_INPUT(gate);
return moe_cuda_expert_count(weight, gate); return moe_cuda_expert_count(gate, num_expert);
} }
std::vector<torch::Tensor> moe_local_scatter( std::vector<torch::Tensor> moe_local_scatter(
......
...@@ -11,7 +11,7 @@ class MOEFunction(Function): ...@@ -11,7 +11,7 @@ class MOEFunction(Function):
def forward(ctx, inp, gate, weight): def forward(ctx, inp, gate, weight):
# out_feat, in_feat = weight.size()[1:] # out_feat, in_feat = weight.size()[1:]
# weight_column_major = weight.transpose(-1, -2).contiguous().view(-1, out_feat, in_feat) # weight_column_major = weight.transpose(-1, -2).contiguous().view(-1, out_feat, in_feat)
expert_count, pos = moe_cuda.expert_count(weight, gate) expert_count, pos = moe_cuda.expert_count(gate, weight.shape[0])
input_buf, = moe_cuda.local_scatter(inp, pos) input_buf, = moe_cuda.local_scatter(inp, pos)
output_buf, = moe_cuda.forward(input_buf, weight, expert_count) output_buf, = moe_cuda.forward(input_buf, weight, expert_count)
output = moe_cuda.local_gather(output_buf, pos) output = moe_cuda.local_gather(output_buf, pos)
......
...@@ -199,10 +199,9 @@ void moe_cuda_backward_impl( ...@@ -199,10 +199,9 @@ void moe_cuda_backward_impl(
std::vector<torch::Tensor> moe_cuda_expert_count( std::vector<torch::Tensor> moe_cuda_expert_count(
torch::Tensor weight, torch::Tensor gate,
torch::Tensor gate) { size_t num_expert) {
const auto batch_size = gate.size(0); const auto batch_size = gate.size(0);
const auto num_expert = weight.size(0);
auto ec_options = torch::TensorOptions().dtype(torch::kInt32); auto ec_options = torch::TensorOptions().dtype(torch::kInt32);
auto expert_count = torch::empty(num_expert, ec_options); auto expert_count = torch::empty(num_expert, ec_options);
......
...@@ -3,7 +3,7 @@ export PYTHONPATH=$PWD/build/lib.linux-x86_64-3.7 ...@@ -3,7 +3,7 @@ export PYTHONPATH=$PWD/build/lib.linux-x86_64-3.7
export LD_LIBRARY_PATH=/home/laekov/.local/lib/python3.7/site-packages/torch/lib:$LD_LIBRARY_PATH export LD_LIBRARY_PATH=/home/laekov/.local/lib/python3.7/site-packages/torch/lib:$LD_LIBRARY_PATH
if [ -z $1 ] if [ -z $1 ]
then then
python moe.py python3 moe.py
elif [ .$1 = '.test_all' ] elif [ .$1 = '.test_all' ]
then then
for nexp in 1 2 4 for nexp in 1 2 4
...@@ -15,11 +15,11 @@ then ...@@ -15,11 +15,11 @@ then
for bs in 4 16 64 256 512 1024 2048 4096 for bs in 4 16 64 256 512 1024 2048 4096
do do
echo $bs $nexp ${inf}x${ouf} echo $bs $nexp ${inf}x${ouf}
python moe_test.py $bs $inf $ouf $nexp python3 moe_test.py $bs $inf $ouf $nexp
done done
done done
done done
done done
else else
python $@ python3 $@
fi fi
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment