Commit 062f1a2d (unverified), authored by Wentao Ye, committed by GitHub
Browse files

[Bug] Fix compile error for `swap_blocks_batch` in CUDA 13 (#38915)

parent 81994e1d
...@@ -91,9 +91,9 @@ void swap_blocks_batch(const torch::Tensor& src_ptrs, ...@@ -91,9 +91,9 @@ void swap_blocks_batch(const torch::Tensor& src_ptrs,
if (n == 0) return; if (n == 0) return;
const int64_t* src_data = src_ptrs.data_ptr<int64_t>(); int64_t* src_data = src_ptrs.mutable_data_ptr<int64_t>();
const int64_t* dst_data = dst_ptrs.data_ptr<int64_t>(); int64_t* dst_data = dst_ptrs.mutable_data_ptr<int64_t>();
const int64_t* size_data = sizes.data_ptr<int64_t>(); int64_t* size_data = sizes.mutable_data_ptr<int64_t>();
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
...@@ -107,15 +107,24 @@ void swap_blocks_batch(const torch::Tensor& src_ptrs, ...@@ -107,15 +107,24 @@ void swap_blocks_batch(const torch::Tensor& src_ptrs,
CUmemcpyAttributes attr = {}; CUmemcpyAttributes attr = {};
attr.srcAccessOrder = CU_MEMCPY_SRC_ACCESS_ORDER_STREAM; attr.srcAccessOrder = CU_MEMCPY_SRC_ACCESS_ORDER_STREAM;
size_t attrs_idx = 0; size_t attrs_idx = 0;
#if defined(CUDA_VERSION) && CUDA_VERSION >= 13000
CUresult result = cuMemcpyBatchAsync(
reinterpret_cast<CUdeviceptr*>(dst_data),
reinterpret_cast<CUdeviceptr*>(src_data),
reinterpret_cast<size_t*>(size_data), static_cast<size_t>(n), &attr,
&attrs_idx, 1, static_cast<CUstream>(stream));
TORCH_CHECK(result == CUDA_SUCCESS, "cuMemcpyBatchAsync failed with error ",
result);
#else
size_t fail_idx = 0; size_t fail_idx = 0;
CUresult result = cuMemcpyBatchAsync( CUresult result = cuMemcpyBatchAsync(
reinterpret_cast<CUdeviceptr*>(const_cast<int64_t*>(dst_data)), reinterpret_cast<CUdeviceptr*>(dst_data),
reinterpret_cast<CUdeviceptr*>(const_cast<int64_t*>(src_data)), reinterpret_cast<CUdeviceptr*>(src_data),
reinterpret_cast<size_t*>(const_cast<int64_t*>(size_data)), reinterpret_cast<size_t*>(size_data), static_cast<size_t>(n), &attr,
static_cast<size_t>(n), &attr, &attrs_idx, 1, &fail_idx, &attrs_idx, 1, &fail_idx, static_cast<CUstream>(stream));
static_cast<CUstream>(stream));
TORCH_CHECK(result == CUDA_SUCCESS, "cuMemcpyBatchAsync failed at index ", TORCH_CHECK(result == CUDA_SUCCESS, "cuMemcpyBatchAsync failed at index ",
fail_idx, " with error ", result); fail_idx, " with error ", result);
#endif
#else #else
// Fallback for CUDA < 12.8 and ROCm: individual async copies. // Fallback for CUDA < 12.8 and ROCm: individual async copies.
// cudaMemcpyDefault lets the driver infer direction from pointer types. // cudaMemcpyDefault lets the driver infer direction from pointer types.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment