Commit 93291a7e authored by Jiezhong Qiu

update

parent c5d719cf
@@ -197,8 +197,7 @@ void moe_cuda_grad_weight(
         const size_t batch_size,
         const size_t in_feat,
         const size_t out_feat,
-        const size_t num_expert,
-        cublasOperation_t transb) {
+        const size_t num_expert) {
     Helper* h = getHelper(num_expert);
@@ -207,7 +206,7 @@ void moe_cuda_grad_weight(
     checkCudaErrors(cudaMemcpy(gate_host, gate, batch_size * sizeof(int), cudaMemcpyDeviceToHost));
     for (size_t i=0; i<batch_size; ++i) {
         checkCudaErrors(cublasSetStream(h->handle, *(h->streams + gate_host[i])));
-        checkCudaErrors(cublasSgemm(h->handle,
+        checkCudaErrors(cublasXgemm(h->handle,
             CUBLAS_OP_N,
             CUBLAS_OP_N,
             out_feat,
@@ -283,6 +282,20 @@ std::vector<torch::Tensor> moe_cuda_backward(
             CUBLAS_OP_N
         );
     }));
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_cuda_backward", ([&] {
+        moe_cuda_grad_weight<scalar_t>(
+            input.data_ptr<scalar_t>(),
+            gate.data_ptr<int>(),
+            grad_output.data_ptr<scalar_t>(),
+            grad_weight.data_ptr<scalar_t>(),
+            batch_size,
+            out_feat,
+            in_feat,
+            num_expert
+        );
+    }));
     return {grad_input, grad_weight};
 }
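
Note on the cublasXgemm change: the second hunk replaces the float-only cublasSgemm call with cublasXgemm, whose definition is not shown in this commit. A plausible reading, given the templated moe_cuda_grad_weight<scalar_t> dispatch added in the last hunk, is an overload set that forwards to the precision-specific cuBLAS GEMM. A minimal sketch under that assumption (the helper in this repository may be implemented differently):

// Sketch only: assumes cublasXgemm is an overload set, defined elsewhere in the
// repository, that forwards to the precision-specific cuBLAS GEMM routine.
#include <cublas_v2.h>

inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
        cublasOperation_t transa, cublasOperation_t transb,
        int m, int n, int k,
        const float* alpha, const float* A, int lda,
        const float* B, int ldb,
        const float* beta, float* C, int ldc) {
    // float arguments resolve to the single-precision GEMM
    return cublasSgemm(handle, transa, transb, m, n, k,
                       alpha, A, lda, B, ldb, beta, C, ldc);
}

inline cublasStatus_t cublasXgemm(cublasHandle_t handle,
        cublasOperation_t transa, cublasOperation_t transb,
        int m, int n, int k,
        const double* alpha, const double* A, int lda,
        const double* B, int ldb,
        const double* beta, double* C, int ldc) {
    // double arguments resolve to the double-precision GEMM
    return cublasDgemm(handle, transa, transb, m, n, k,
                       alpha, A, lda, B, ldb, beta, C, ldc);
}

With such overloads the templated kernel body stays identical for float and double; the compiler picks the right GEMM from the deduced scalar_t of the pointer arguments.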
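Note on the AT_DISPATCH_FLOATING_TYPES block: the third hunk wires the grad_weight path into the usual PyTorch extension dispatch pattern. The macro switches on the tensor's runtime dtype and instantiates the lambda with scalar_t bound to float or double, so one call site covers both precisions. A minimal, self-contained sketch of the mechanism (the scale_impl/scale names are hypothetical and only illustrate the pattern; they are not part of this repository):

#include <torch/extension.h>

// Hypothetical templated worker; stands in for moe_cuda_grad_weight<scalar_t>.
template <typename scalar_t>
void scale_impl(scalar_t* data, scalar_t factor, int64_t n) {
    for (int64_t i = 0; i < n; ++i) {
        data[i] *= factor;
    }
}

// The macro expands to a switch on input.scalar_type() and instantiates the
// lambda body once per floating dtype.
void scale(torch::Tensor input, double factor) {
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "scale", ([&] {
        scale_impl<scalar_t>(
            input.data_ptr<scalar_t>(),
            static_cast<scalar_t>(factor),
            input.numel());
    }));
}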