Commit 861b75c1 authored by Rick Ho
Browse files

Remove the weight input from the C++ expert count function

parent 60b93e39
......@@ -5,8 +5,7 @@
#include <vector>
std::vector<torch::Tensor> moe_cuda_expert_count(
torch::Tensor weight, // TODO: pass num-experts in another way?
torch::Tensor gate);
torch::Tensor gate, size_t num_expert);
std::vector<torch::Tensor> moe_cuda_local_scatter(
torch::Tensor input,
......@@ -35,10 +34,10 @@ std::vector<torch::Tensor> moe_cuda_backward(
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> moe_expert_count(
torch::Tensor weight,
torch::Tensor gate) {
torch::Tensor gate,
size_t num_expert) {
CHECK_INPUT(gate);
return moe_cuda_expert_count(weight, gate);
return moe_cuda_expert_count(gate, num_expert);
}
std::vector<torch::Tensor> moe_local_scatter(
......
......@@ -11,7 +11,7 @@ class MOEFunction(Function):
def forward(ctx, inp, gate, weight):
# out_feat, in_feat = weight.size()[1:]
# weight_column_major = weight.transpose(-1, -2).contiguous().view(-1, out_feat, in_feat)
expert_count, pos = moe_cuda.expert_count(weight, gate)
expert_count, pos = moe_cuda.expert_count(gate, weight.shape[0])
input_buf, = moe_cuda.local_scatter(inp, pos)
output_buf, = moe_cuda.forward(input_buf, weight, expert_count)
output = moe_cuda.local_gather(output_buf, pos)
......
......@@ -199,10 +199,9 @@ void moe_cuda_backward_impl(
std::vector<torch::Tensor> moe_cuda_expert_count(
torch::Tensor weight,
torch::Tensor gate) {
torch::Tensor gate,
size_t num_expert) {
const auto batch_size = gate.size(0);
const auto num_expert = weight.size(0);
auto ec_options = torch::TensorOptions().dtype(torch::kInt32);
auto expert_count = torch::empty(num_expert, ec_options);
......
......@@ -3,7 +3,7 @@ export PYTHONPATH=$PWD/build/lib.linux-x86_64-3.7
export LD_LIBRARY_PATH=/home/laekov/.local/lib/python3.7/site-packages/torch/lib:$LD_LIBRARY_PATH
if [ -z $1 ]
then
python moe.py
python3 moe.py
elif [ .$1 = '.test_all' ]
then
for nexp in 1 2 4
......@@ -15,11 +15,11 @@ then
for bs in 4 16 64 256 512 1024 2048 4096
do
echo $bs $nexp ${inf}x${ouf}
python moe_test.py $bs $inf $ouf $nexp
python3 moe_test.py $bs $inf $ouf $nexp
done
done
done
done
else
python $@
python3 $@
fi
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment