// torch_utils.h
#pragma once

#include <cstdint>

#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/headeronly/util/shim_utils.h>

#include <cuda_runtime.h>

// Utility to get the current CUDA stream for a given device using stable APIs.
// Returns a cudaStream_t for use in kernel launches.
10
inline cudaStream_t get_current_cuda_stream(int32_t device_index = -1) {
11
12
13
14
15
  void* stream_ptr = nullptr;
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_get_current_cuda_stream(device_index, &stream_ptr));
  return reinterpret_cast<cudaStream_t>(stream_ptr);
}