Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
989ae253
Unverified
Commit
989ae253
authored
Apr 13, 2024
by
Jee Li
Committed by
GitHub
Apr 13, 2024
Browse files
[Kernel] Add punica dimension for Baichuan-13B (#4053)
parent
0a430b4a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
3 additions
and
1 deletion
+3
-1
csrc/punica/bgmv/bgmv_config.h
csrc/punica/bgmv/bgmv_config.h
+1
-0
tests/lora/test_baichuan.py
tests/lora/test_baichuan.py
+1
-1
tests/lora/test_punica.py
tests/lora/test_punica.py
+1
-0
No files found.
csrc/punica/bgmv/bgmv_config.h
View file @
989ae253
...
@@ -47,6 +47,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
...
@@ -47,6 +47,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 13696) \
f(in_T, out_T, W_T, narrow, 13696) \
f(in_T, out_T, W_T, narrow, 13824) \
f(in_T, out_T, W_T, narrow, 13824) \
f(in_T, out_T, W_T, narrow, 14336) \
f(in_T, out_T, W_T, narrow, 14336) \
f(in_T, out_T, W_T, narrow, 15360) \
f(in_T, out_T, W_T, narrow, 16384) \
f(in_T, out_T, W_T, narrow, 16384) \
f(in_T, out_T, W_T, narrow, 20480) \
f(in_T, out_T, W_T, narrow, 20480) \
f(in_T, out_T, W_T, narrow, 22016) \
f(in_T, out_T, W_T, narrow, 22016) \
...
...
tests/lora/test_baichuan.py
View file @
989ae253
...
@@ -62,7 +62,7 @@ def test_baichuan_lora(baichuan_lora_files):
...
@@ -62,7 +62,7 @@ def test_baichuan_lora(baichuan_lora_files):
@
pytest
.
mark
.
skip
(
"Requires multiple GPUs"
)
@
pytest
.
mark
.
skip
(
"Requires multiple GPUs"
)
def
test_
llama
_tensor_parallel_equality
(
baichuan_lora_files
):
def
test_
baichuan
_tensor_parallel_equality
(
baichuan_lora_files
):
# Cannot use as it will initialize torch.cuda too early...
# Cannot use as it will initialize torch.cuda too early...
# if torch.cuda.device_count() < 4:
# if torch.cuda.device_count() < 4:
# pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
# pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
...
...
tests/lora/test_punica.py
View file @
989ae253
...
@@ -72,6 +72,7 @@ H1 = H2 = [
...
@@ -72,6 +72,7 @@ H1 = H2 = [
11008
,
11008
,
13824
,
13824
,
14336
,
14336
,
15360
,
22016
,
22016
,
24576
,
24576
,
27392
,
27392
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment