Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
468d761b
Unverified
Commit
468d761b
authored
Apr 23, 2024
by
Woosuk Kwon
Committed by
GitHub
Apr 23, 2024
Browse files
[Misc] Reduce supported Punica dtypes (#4304)
parent
e4bf860a
Changes
16
Show whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
66 additions
and
72 deletions
+66
-72
CMakeLists.txt
CMakeLists.txt
+0
-12
csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu
csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu
+0
-4
csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu
csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu
+0
-4
csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu
csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu
+0
-4
csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu
csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu
+0
-4
csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu
csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu
+0
-4
csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu
csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu
+0
-4
csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu
csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu
+0
-4
csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu
csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu
+0
-4
csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu
csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu
+0
-4
csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu
csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu
+0
-4
csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu
csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu
+0
-4
csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu
csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu
+0
-4
csrc/punica/bgmv/generator.py
csrc/punica/bgmv/generator.py
+20
-0
csrc/punica/punica_ops.cc
csrc/punica/punica_ops.cc
+17
-0
tests/lora/test_layers.py
tests/lora/test_layers.py
+29
-12
No files found.
CMakeLists.txt
View file @
468d761b
...
@@ -212,23 +212,11 @@ define_gpu_extension_target(
...
@@ -212,23 +212,11 @@ define_gpu_extension_target(
set
(
VLLM_PUNICA_EXT_SRC
set
(
VLLM_PUNICA_EXT_SRC
"csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu"
"csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu"
"csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu"
"csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu"
"csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu"
"csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu"
"csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu"
"csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu"
"csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu"
"csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu"
"csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu"
"csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu"
"csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu"
"csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu"
"csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu"
"csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu"
"csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu"
"csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu"
"csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu"
"csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu"
"csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu"
"csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu"
"csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu"
"csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu"
"csrc/punica/punica_ops.cc"
)
"csrc/punica/punica_ops.cc"
)
#
#
...
...
csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu
deleted
100644 → 0
View file @
e4bf860a
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
nv_bfloat16
,
nv_bfloat16
,
nv_half
)
csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu
deleted
100644 → 0
View file @
e4bf860a
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
nv_bfloat16
,
nv_half
,
nv_bfloat16
)
csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu
deleted
100644 → 0
View file @
e4bf860a
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
nv_bfloat16
,
nv_half
,
nv_half
)
csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu
deleted
100644 → 0
View file @
e4bf860a
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
nv_bfloat16
,
float
,
nv_half
)
csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu
deleted
100644 → 0
View file @
e4bf860a
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
nv_half
,
nv_bfloat16
,
nv_bfloat16
)
csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu
deleted
100644 → 0
View file @
e4bf860a
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
nv_half
,
nv_bfloat16
,
nv_half
)
csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu
deleted
100644 → 0
View file @
e4bf860a
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
nv_half
,
nv_half
,
nv_bfloat16
)
csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu
deleted
100644 → 0
View file @
e4bf860a
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
nv_half
,
float
,
nv_bfloat16
)
csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu
deleted
100644 → 0
View file @
e4bf860a
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
float
,
nv_bfloat16
,
nv_half
)
csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu
deleted
100644 → 0
View file @
e4bf860a
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
float
,
nv_half
,
nv_bfloat16
)
csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu
deleted
100644 → 0
View file @
e4bf860a
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
float
,
float
,
nv_bfloat16
)
csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu
deleted
100644 → 0
View file @
e4bf860a
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW
(
INST_BGMV_TWOSIDE
,
float
,
float
,
nv_half
)
csrc/punica/bgmv/generator.py
View file @
468d761b
...
@@ -18,6 +18,26 @@ for input_dtype in DTYPES:
...
@@ -18,6 +18,26 @@ for input_dtype in DTYPES:
if
weight_dtype
==
"fp32"
:
if
weight_dtype
==
"fp32"
:
# FP32 weights are not supported.
# FP32 weights are not supported.
continue
continue
if
output_dtype
==
"fp32"
:
# LoRA A matrix.
if
input_dtype
!=
weight_dtype
:
# NOTE(woosuk): While Punica supports the case where the
# input and weight dtypes are different, we only generate
# the kernels with the same dtypes to reduce the binary size.
continue
elif
input_dtype
==
"fp32"
:
# LoRA B matrix.
if
output_dtype
!=
weight_dtype
:
# NOTE(woosuk): While Punica supports the case where the
# output and weight dtypes are different, we only generate
# the kernels with the same dtypes to reduce the binary size.
continue
elif
not
(
input_dtype
==
output_dtype
==
weight_dtype
):
# NOTE(woosuk): While Punica supports mixed data types for
# input, output, and weight, we only generate the kernels with
# the same data types to reduce the binary size.
continue
kernel_definition
=
TEMPLATE
.
format
(
kernel_definition
=
TEMPLATE
.
format
(
input_dtype
=
DTYPE_MAP
[
input_dtype
],
input_dtype
=
DTYPE_MAP
[
input_dtype
],
output_dtype
=
DTYPE_MAP
[
output_dtype
],
output_dtype
=
DTYPE_MAP
[
output_dtype
],
...
...
csrc/punica/punica_ops.cc
View file @
468d761b
...
@@ -50,6 +50,23 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
...
@@ -50,6 +50,23 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
int64_t
y_offset
,
int64_t
full_y_size
,
int64_t
y_offset
,
int64_t
full_y_size
,
int64_t
batch_size
,
int64_t
num_layers
,
int64_t
batch_size
,
int64_t
num_layers
,
int64_t
layer_idx
,
float
scale
)
{
int64_t
layer_idx
,
float
scale
)
{
// NOTE(woosuk): While Punica supports various combinations of input/output
// data types, we limit the supported data types to reduce the binary size.
constexpr
bool
is_input_float
=
std
::
is_same
<
in_T
,
float
>::
value
;
constexpr
bool
is_output_float
=
std
::
is_same
<
out_T
,
float
>::
value
;
if
(
is_input_float
)
{
if
(
!
std
::
is_same
<
out_T
,
W_T
>::
value
)
{
return
false
;
}
}
else
if
(
is_output_float
)
{
if
(
!
std
::
is_same
<
in_T
,
W_T
>::
value
)
{
return
false
;
}
}
else
if
(
!
(
std
::
is_same
<
in_T
,
W_T
>::
value
&&
std
::
is_same
<
out_T
,
W_T
>::
value
))
{
return
false
;
}
switch
(
pack_u32
(
in_features
,
out_features
))
{
switch
(
pack_u32
(
in_features
,
out_features
))
{
#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
case pack_u32(feat_in, feat_out): \
case pack_u32(feat_in, feat_out): \
...
...
tests/lora/test_layers.py
View file @
468d761b
...
@@ -413,7 +413,9 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
...
@@ -413,7 +413,9 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
def
_pretest
():
def
_pretest
():
linear
=
ParallelLMHead
(
vocab_size
+
lora_config
.
lora_extra_vocab_size
,
linear
=
ParallelLMHead
(
vocab_size
+
lora_config
.
lora_extra_vocab_size
,
1024
,
vocab_size
)
1024
,
vocab_size
,
params_dtype
=
torch
.
float16
)
linear
.
weight
.
data
=
torch
.
rand_like
(
linear
.
weight
.
data
)
linear
.
weight
.
data
=
torch
.
rand_like
(
linear
.
weight
.
data
)
linear
.
weight
.
data
[:,
vocab_size
:]
=
0
linear
.
weight
.
data
[:,
vocab_size
:]
=
0
logits_processor
=
LogitsProcessor
(
logits_processor
=
LogitsProcessor
(
...
@@ -445,7 +447,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
...
@@ -445,7 +447,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
num_inputs
=
8
*
num_loras
,
# * 3,
num_inputs
=
8
*
num_loras
,
# * 3,
input_size
=
(
1
,
1024
),
input_size
=
(
1
,
1024
),
input_range
=
(
0
,
1
),
input_range
=
(
0
,
1
),
input_type
=
torch
.
float
32
,
input_type
=
torch
.
float
16
,
)
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
)
...
@@ -494,7 +496,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
...
@@ -494,7 +496,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
num_inputs
=
8
*
num_loras
*
3
,
num_inputs
=
8
*
num_loras
*
3
,
input_size
=
(
1
,
1024
),
input_size
=
(
1
,
1024
),
input_range
=
(
0
,
1
),
input_range
=
(
0
,
1
),
input_type
=
torch
.
float
32
,
input_type
=
torch
.
float
16
,
)
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
)
...
@@ -533,11 +535,17 @@ def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
...
@@ -533,11 +535,17 @@ def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
def
create_random_linear_parallel_layer
():
def
create_random_linear_parallel_layer
():
if
orientation
==
"row"
:
if
orientation
==
"row"
:
linear
=
RowParallelLinear
(
4096
,
4096
,
bias
=
False
)
linear
=
RowParallelLinear
(
4096
,
4096
,
bias
=
False
,
params_dtype
=
torch
.
float16
)
linear
.
weight
.
data
=
torch
.
rand_like
(
linear
.
weight
.
data
)
linear
.
weight
.
data
=
torch
.
rand_like
(
linear
.
weight
.
data
)
lora_linear
=
RowParallelLinearWithLoRA
(
linear
)
lora_linear
=
RowParallelLinearWithLoRA
(
linear
)
else
:
else
:
linear
=
ColumnParallelLinear
(
4096
,
4096
,
bias
=
False
)
linear
=
ColumnParallelLinear
(
4096
,
4096
,
bias
=
False
,
params_dtype
=
torch
.
float16
)
linear
.
weight
.
data
=
torch
.
rand_like
(
linear
.
weight
.
data
)
linear
.
weight
.
data
=
torch
.
rand_like
(
linear
.
weight
.
data
)
lora_linear
=
ColumnParallelLinearWithLoRA
(
linear
)
lora_linear
=
ColumnParallelLinearWithLoRA
(
linear
)
lora_linear
.
create_lora_weights
(
max_loras
,
lora_config
)
lora_linear
.
create_lora_weights
(
max_loras
,
lora_config
)
...
@@ -561,7 +569,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
...
@@ -561,7 +569,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
num_inputs
=
32
*
num_loras
,
num_inputs
=
32
*
num_loras
,
input_size
=
(
1
,
4096
),
input_size
=
(
1
,
4096
),
input_range
=
(
0
,
1
),
input_range
=
(
0
,
1
),
input_type
=
torch
.
float
32
,
input_type
=
torch
.
float
16
,
)
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
)
...
@@ -600,7 +608,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
...
@@ -600,7 +608,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
num_inputs
=
32
*
num_loras
,
num_inputs
=
32
*
num_loras
,
input_size
=
(
1
,
4096
),
input_size
=
(
1
,
4096
),
input_range
=
(
0
,
1
),
input_range
=
(
0
,
1
),
input_type
=
torch
.
float
32
,
input_type
=
torch
.
float
16
,
)
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
)
...
@@ -633,15 +641,24 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
...
@@ -633,15 +641,24 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
def
create_column_parallel_packed_layer
():
def
create_column_parallel_packed_layer
():
if
repeats
==
2
:
if
repeats
==
2
:
linear
=
MergedColumnParallelLinear
(
4096
,
[
4096
]
*
repeats
,
linear
=
MergedColumnParallelLinear
(
4096
,
[
4096
]
*
repeats
,
bias
=
False
)
bias
=
False
,
params_dtype
=
torch
.
float16
)
linear
.
weight
.
data
=
torch
.
rand_like
(
linear
.
weight
.
data
)
linear
.
weight
.
data
=
torch
.
rand_like
(
linear
.
weight
.
data
)
lora_linear
=
MergedColumnParallelLinearWithLoRA
(
linear
)
lora_linear
=
MergedColumnParallelLinearWithLoRA
(
linear
)
elif
repeats
==
3
:
elif
repeats
==
3
:
linear
=
QKVParallelLinear
(
4096
,
64
,
32
,
bias
=
False
)
linear
=
QKVParallelLinear
(
4096
,
64
,
32
,
bias
=
False
,
params_dtype
=
torch
.
float16
)
linear
.
weight
.
data
=
torch
.
rand_like
(
linear
.
weight
.
data
)
linear
.
weight
.
data
=
torch
.
rand_like
(
linear
.
weight
.
data
)
lora_linear
=
MergedQKVParallelLinearWithLora
(
linear
)
lora_linear
=
MergedQKVParallelLinearWithLora
(
linear
)
else
:
else
:
linear
=
QKVParallelLinear
(
4096
,
64
,
32
,
bias
=
False
)
linear
=
QKVParallelLinear
(
4096
,
64
,
32
,
bias
=
False
,
params_dtype
=
torch
.
float16
)
linear
.
weight
.
data
=
torch
.
rand_like
(
linear
.
weight
.
data
)
linear
.
weight
.
data
=
torch
.
rand_like
(
linear
.
weight
.
data
)
lora_linear
=
QKVParallelLinearWithLora
(
linear
)
lora_linear
=
QKVParallelLinearWithLora
(
linear
)
...
@@ -676,7 +693,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
...
@@ -676,7 +693,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
num_inputs
=
32
*
num_loras
,
num_inputs
=
32
*
num_loras
,
input_size
=
(
1
,
4096
),
input_size
=
(
1
,
4096
),
input_range
=
(
0
,
1
),
input_range
=
(
0
,
1
),
input_type
=
torch
.
float
32
,
input_type
=
torch
.
float
16
,
)
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
)
...
@@ -716,7 +733,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
...
@@ -716,7 +733,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
num_inputs
=
32
*
num_loras
,
num_inputs
=
32
*
num_loras
,
input_size
=
(
1
,
4096
),
input_size
=
(
1
,
4096
),
input_range
=
(
0
,
1
),
input_range
=
(
0
,
1
),
input_type
=
torch
.
float
32
,
input_type
=
torch
.
float
16
,
)
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment