Unverified commit 2cae2907, authored by Casper, committed by GitHub
Browse files

Add device guard (fix multi-GPU) (#10)

parent bad253e6
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>
@@ -75,6 +76,10 @@ void moe_alig_block_size(
torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad) {
const at::cuda::OptionalCUDAGuard device_guard_topk_ids(device_of(topk_ids));
const at::cuda::OptionalCUDAGuard device_guard_sorted(device_of(sorted_token_ids));
const at::cuda::OptionalCUDAGuard device_guard_experts(device_of(experts_ids));
const at::cuda::OptionalCUDAGuard device_guard_num_tokens(device_of(num_tokens_post_pad));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
assert(num_experts <= NUM_MAX_EXPERTS);
VLLM_DISPATCH_INTEGRAL_TYPES(
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.