Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
59fff4a0
Unverified
Commit
59fff4a0
authored
Feb 10, 2025
by
youkaichao
Committed by
GitHub
Feb 10, 2025
Browse files
[core] improve error handling when wake up from sleep mode (#12981)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
29f1d47e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
78 additions
and
12 deletions
+78
-12
csrc/cumem_allocator.cpp
csrc/cumem_allocator.cpp
+51
-12
tests/basic_correctness/test_cumem.py
tests/basic_correctness/test_cumem.py
+27
-0
No files found.
csrc/cumem_allocator.cpp
View file @
59fff4a0
...
...
@@ -12,15 +12,21 @@ extern "C" {
#include <cuda_runtime_api.h>
#include <cuda.h>
#define CUDA_CHECK(condition) \
do { \
CUresult error = condition; \
if (error != 0) { \
char* error_string; \
cuGetErrorString(error, (const char**)&error_string); \
std::cerr << "CUDA Error: " << error_string << " at " << __FILE__ << ":" \
<< __LINE__ << std::endl; \
} \
char
error_msg
[
10240
];
// 10KB buffer to store error messages
CUresult
no_error
=
CUresult
(
0
);
CUresult
error_code
=
no_error
;
// store error code
#define CUDA_CHECK(condition) \
do { \
CUresult error = condition; \
if (error != 0) { \
error_code = error; \
char* error_string; \
cuGetErrorString(error, (const char**)&error_string); \
snprintf(error_msg, sizeof(error_msg), "CUDA Error: %s at %s:%d", \
error_string, __FILE__, __LINE__); \
std::cerr << error_msg << std::endl; \
} \
} while (0)
// Global references to Python callables
...
...
@@ -54,14 +60,22 @@ void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
// Allocate memory using cuMemCreate
CUDA_CHECK
(
cuMemCreate
(
p_memHandle
,
size
,
&
prop
,
0
));
if
(
error_code
!=
0
)
{
return
;
}
CUDA_CHECK
(
cuMemMap
(
d_mem
,
size
,
0
,
*
p_memHandle
,
0
));
if
(
error_code
!=
0
)
{
return
;
}
CUmemAccessDesc
accessDesc
=
{};
accessDesc
.
location
.
type
=
CU_MEM_LOCATION_TYPE_DEVICE
;
accessDesc
.
location
.
id
=
device
;
accessDesc
.
flags
=
CU_MEM_ACCESS_FLAGS_PROT_READWRITE
;
CUDA_CHECK
(
cuMemSetAccess
(
d_mem
,
size
,
&
accessDesc
,
1
));
if
(
error_code
!=
0
)
{
return
;
}
// std::cout << "create_and_map: device=" << device << ", size=" << size << ",
// d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl;
}
...
...
@@ -73,7 +87,13 @@ void unmap_and_release(unsigned long long device, ssize_t size,
// ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl;
ensure_context
(
device
);
CUDA_CHECK
(
cuMemUnmap
(
d_mem
,
size
));
if
(
error_code
!=
0
)
{
return
;
}
CUDA_CHECK
(
cuMemRelease
(
*
p_memHandle
));
if
(
error_code
!=
0
)
{
return
;
}
}
PyObject
*
create_tuple_from_c_integers
(
unsigned
long
long
a
,
...
...
@@ -121,12 +141,16 @@ void* my_malloc(ssize_t size, int device, CUstream stream) {
size_t
granularity
;
CUDA_CHECK
(
cuMemGetAllocationGranularity
(
&
granularity
,
&
prop
,
CU_MEM_ALLOC_GRANULARITY_MINIMUM
));
if
(
error_code
!=
0
)
{
return
nullptr
;
}
size_t
alignedSize
=
((
size
+
granularity
-
1
)
/
granularity
)
*
granularity
;
CUdeviceptr
d_mem
;
CUDA_CHECK
(
cuMemAddressReserve
(
&
d_mem
,
alignedSize
,
0
,
0
,
0
));
if
(
error_code
!=
0
)
{
return
nullptr
;
}
// allocate the CUmemGenericAllocationHandle
CUmemGenericAllocationHandle
*
p_memHandle
=
(
CUmemGenericAllocationHandle
*
)
malloc
(
...
...
@@ -208,6 +232,9 @@ void my_free(void* ptr, ssize_t size, int device, CUstream stream) {
// free address and the handle
CUDA_CHECK
(
cuMemAddressFree
(
d_mem
,
size
));
if
(
error_code
!=
0
)
{
return
;
}
free
(
p_memHandle
);
}
...
...
@@ -258,6 +285,12 @@ static PyObject* python_unmap_and_release(PyObject* self, PyObject* args) {
unmap_and_release
(
recv_device
,
recv_size
,
d_mem_ptr
,
p_memHandle
);
if
(
error_code
!=
0
)
{
error_code
=
no_error
;
PyErr_SetString
(
PyExc_RuntimeError
,
error_msg
);
return
nullptr
;
}
Py_RETURN_NONE
;
}
...
...
@@ -282,6 +315,12 @@ static PyObject* python_create_and_map(PyObject* self, PyObject* args) {
create_and_map
(
recv_device
,
recv_size
,
d_mem_ptr
,
p_memHandle
);
if
(
error_code
!=
0
)
{
error_code
=
no_error
;
PyErr_SetString
(
PyExc_RuntimeError
,
error_msg
);
return
nullptr
;
}
Py_RETURN_NONE
;
}
...
...
tests/basic_correctness/test_cumem.py
View file @
59fff4a0
# SPDX-License-Identifier: Apache-2.0
import
pytest
import
torch
from
vllm
import
LLM
,
SamplingParams
...
...
@@ -9,6 +10,32 @@ from vllm.utils import GiB_bytes
from
..utils
import
fork_new_process_for_each_test
@
fork_new_process_for_each_test
def
test_python_error
():
"""
Test if Python error occurs when there's low-level
error happening from the C++ side.
"""
allocator
=
CuMemAllocator
.
get_instance
()
total_bytes
=
torch
.
cuda
.
mem_get_info
()[
1
]
alloc_bytes
=
int
(
total_bytes
*
0.7
)
tensors
=
[]
with
allocator
.
use_memory_pool
():
# allocate 70% of the total memory
x
=
torch
.
empty
(
alloc_bytes
,
dtype
=
torch
.
uint8
,
device
=
'cuda'
)
tensors
.
append
(
x
)
# release the memory
allocator
.
sleep
()
# allocate more memory than the total memory
y
=
torch
.
empty
(
alloc_bytes
,
dtype
=
torch
.
uint8
,
device
=
'cuda'
)
tensors
.
append
(
y
)
with
pytest
.
raises
(
RuntimeError
):
# when the allocator is woken up, it should raise an error
# because we don't have enough memory
allocator
.
wake_up
()
@
fork_new_process_for_each_test
def
test_basic_cumem
():
# some tensors from default memory pool
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment