Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
566b57c5
Unverified
Commit
566b57c5
authored
Mar 27, 2024
by
Jee Li
Committed by
GitHub
Mar 27, 2024
Browse files
[Kernel] support non-zero cuda devices in punica kernels (#3636)
parent
0dc72273
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
29 additions
and
62 deletions
+29
-62
csrc/punica/punica_ops.cc
csrc/punica/punica_ops.cc
+3
-1
tests/lora/test_punica.py
tests/lora/test_punica.py
+26
-61
No files found.
csrc/punica/punica_ops.cc
View file @
566b57c5
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cstdint>
#include "bgmv/bgmv_config.h"
...
...
@@ -91,6 +91,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
CHECK_EQ
(
w
.
size
(
2
),
h_out
);
CHECK_EQ
(
indicies
.
size
(
0
),
x
.
size
(
0
));
CHECK_EQ
(
y
.
size
(
0
),
x
.
size
(
0
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
x
));
bool
ok
=
false
;
if
(
h_in
<
65536
&&
h_out
<
65536
)
{
// TODO: See if we can get rid of this massive nested switch
...
...
@@ -322,6 +323,7 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
CHECK_EQ
(
w
.
size
(
2
),
h_out
);
CHECK_EQ
(
indicies
.
size
(
0
),
x
.
size
(
0
));
CHECK_EQ
(
y
.
size
(
0
),
x
.
size
(
0
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
x
));
bool
ok
=
false
;
if
(
h_in
<
65536
&&
h_out
<
65536
)
{
// TODO: See if we can get rid of this massive nested switch
...
...
tests/lora/test_punica.py
View file @
566b57c5
...
...
@@ -49,14 +49,18 @@ H1 = H2 = [
32768
,
33024
]
SEED
=
[
0xabcdabcd987
]
CUDA_DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
]
@
pytest
.
mark
.
parametrize
(
"dtype_str"
,
[
"float16"
,
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"h1"
,
H1
)
@
pytest
.
mark
.
parametrize
(
"h2"
,
H2
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEED
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
def
test_lora_correctness
(
dtype_str
,
h1
,
h2
,
seed
):
def
test_lora_correctness
(
dtype_str
,
h1
,
h2
,
seed
,
device
):
torch
.
manual_seed
(
seed
)
num_loras
=
4
num_layers
=
1
...
...
@@ -64,25 +68,15 @@ def test_lora_correctness(dtype_str, h1, h2, seed):
bs
=
32
scale
=
0.123
dtype
=
getattr
(
torch
,
dtype_str
)
device
=
torch
.
device
(
"cuda"
)
wa_T_all
=
torch
.
randn
(
num_loras
,
num_layers
,
r
,
h1
,
dtype
=
dtype
,
device
=
device
)
wb_T_all
=
torch
.
randn
(
num_loras
,
num_layers
,
h2
,
r
,
dtype
=
dtype
,
device
=
device
)
indices
=
torch
.
randint
(
num_loras
,
(
bs
,
),
dtype
=
torch
.
long
,
device
=
device
)
torch
.
set_default_device
(
device
)
wa_T_all
=
torch
.
randn
(
num_loras
,
num_layers
,
r
,
h1
,
dtype
=
dtype
)
wb_T_all
=
torch
.
randn
(
num_loras
,
num_layers
,
h2
,
r
,
dtype
=
dtype
)
indices
=
torch
.
randint
(
num_loras
,
(
bs
,
),
dtype
=
torch
.
long
)
for
layer_idx
in
range
(
num_layers
):
x
=
torch
.
randn
(
bs
,
h1
,
dtype
=
dtype
,
device
=
device
)
y
=
torch
.
randn
(
bs
,
h2
,
dtype
=
dtype
,
device
=
device
)
x
=
torch
.
randn
(
bs
,
h1
,
dtype
=
dtype
)
y
=
torch
.
randn
(
bs
,
h2
,
dtype
=
dtype
)
y_ref
=
y
.
clone
()
_lora_ref_impl
(
y_ref
,
x
,
wa_T_all
,
wb_T_all
,
indices
,
layer_idx
,
scale
)
...
...
@@ -98,8 +92,9 @@ def test_lora_correctness(dtype_str, h1, h2, seed):
@
pytest
.
mark
.
parametrize
(
"h1"
,
H1
)
@
pytest
.
mark
.
parametrize
(
"h2"
,
H2
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEED
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
def
test_lora_correctness_slice
(
dtype_str
,
h1
,
h2
,
seed
):
def
test_lora_correctness_slice
(
dtype_str
,
h1
,
h2
,
seed
,
device
):
if
h2
%
3
!=
0
or
h2
//
3
not
in
H1
:
pytest
.
skip
(
"h2 must be divisible by 3 and in supported shapes"
)
torch
.
manual_seed
(
seed
)
...
...
@@ -109,50 +104,20 @@ def test_lora_correctness_slice(dtype_str, h1, h2, seed):
bs
=
32
scale
=
0.123
dtype
=
getattr
(
torch
,
dtype_str
)
device
=
torch
.
device
(
"cuda"
)
wa_T_all_0
=
torch
.
randn
(
num_loras
,
num_layers
,
r
,
h1
,
dtype
=
dtype
,
device
=
device
)
wa_T_all_1
=
torch
.
randn
(
num_loras
,
num_layers
,
r
,
h1
,
dtype
=
dtype
,
device
=
device
)
wa_T_all_2
=
torch
.
randn
(
num_loras
,
num_layers
,
r
,
h1
,
dtype
=
dtype
,
device
=
device
)
wb_T_all_0
=
torch
.
randn
(
num_loras
,
num_layers
,
h2
//
3
,
r
,
dtype
=
dtype
,
device
=
device
)
wb_T_all_1
=
torch
.
randn
(
num_loras
,
num_layers
,
h2
//
3
,
r
,
dtype
=
dtype
,
device
=
device
)
wb_T_all_2
=
torch
.
randn
(
num_loras
,
num_layers
,
h2
//
3
,
r
,
dtype
=
dtype
,
device
=
device
)
indices
=
torch
.
randint
(
num_loras
,
(
bs
,
),
dtype
=
torch
.
long
,
device
=
device
)
torch
.
set_default_device
(
device
)
wa_T_all_0
=
torch
.
randn
(
num_loras
,
num_layers
,
r
,
h1
,
dtype
=
dtype
)
wa_T_all_1
=
torch
.
randn
(
num_loras
,
num_layers
,
r
,
h1
,
dtype
=
dtype
)
wa_T_all_2
=
torch
.
randn
(
num_loras
,
num_layers
,
r
,
h1
,
dtype
=
dtype
)
wb_T_all_0
=
torch
.
randn
(
num_loras
,
num_layers
,
h2
//
3
,
r
,
dtype
=
dtype
)
wb_T_all_1
=
torch
.
randn
(
num_loras
,
num_layers
,
h2
//
3
,
r
,
dtype
=
dtype
)
wb_T_all_2
=
torch
.
randn
(
num_loras
,
num_layers
,
h2
//
3
,
r
,
dtype
=
dtype
)
indices
=
torch
.
randint
(
num_loras
,
(
bs
,
),
dtype
=
torch
.
long
)
for
layer_idx
in
range
(
num_layers
):
x
=
torch
.
randn
(
bs
,
h1
,
dtype
=
dtype
,
device
=
device
)
y
=
torch
.
randn
(
bs
,
h2
,
dtype
=
dtype
,
device
=
device
)
x
=
torch
.
randn
(
bs
,
h1
,
dtype
=
dtype
)
y
=
torch
.
randn
(
bs
,
h2
,
dtype
=
dtype
)
s
=
h2
//
3
y_ref
=
y
.
clone
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment