Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0e63494c
Unverified
Commit
0e63494c
authored
Jul 24, 2024
by
Antoni Baum
Committed by
GitHub
Jul 24, 2024
Browse files
Add fp8 support to `reshape_and_cache_flash` (#6667)
parent
ee812580
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
98 additions
and
43 deletions
+98
-43
csrc/cache.h
csrc/cache.h
+2
-1
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+45
-30
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+2
-1
tests/kernels/test_cache.py
tests/kernels/test_cache.py
+34
-8
vllm/_custom_ops.py
vllm/_custom_ops.py
+4
-1
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+2
-0
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+2
-0
vllm/utils.py
vllm/utils.py
+7
-2
No files found.
csrc/cache.h
View file @
0e63494c
...
...
@@ -25,7 +25,8 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
torch
::
Tensor
&
slot_mapping
,
const
std
::
string
&
kv_cache_dtype
);
const
std
::
string
&
kv_cache_dtype
,
const
double
k_scale
,
const
double
v_scale
);
// Just for unittest
void
convert_fp8
(
torch
::
Tensor
&
dst_cache
,
torch
::
Tensor
&
src_cache
,
...
...
csrc/cache_kernels.cu
View file @
0e63494c
...
...
@@ -203,17 +203,18 @@ __global__ void reshape_and_cache_kernel(
}
}
template
<
typename
scalar_t
>
template
<
typename
scalar_t
,
typename
cache_t
,
Fp8KVCacheDataType
kv_dt
>
__global__
void
reshape_and_cache_flash_kernel
(
const
scalar_t
*
__restrict__
key
,
// [num_tokens, num_heads, head_size]
const
scalar_t
*
__restrict__
value
,
// [num_tokens, num_heads, head_size]
s
ca
lar
_t
*
__restrict__
k_cache
,
// [num_blocks, block_size, num_heads,
ca
che
_t
*
__restrict__
k
ey
_cache
,
// [num_blocks, block_size, num_heads,
// head_size]
s
ca
lar
_t
*
__restrict__
v_cache
,
// [num_blocks, block_size, num_heads,
ca
che
_t
*
__restrict__
v
alue
_cache
,
// [num_blocks, block_size, num_heads,
// head_size]
const
int64_t
*
__restrict__
slot_mapping
,
// [num_tokens]
const
int
block_stride
,
const
int
key_stride
,
const
int
value_stride
,
const
int
num_heads
,
const
int
head_size
,
const
int
block_size
)
{
const
int
num_heads
,
const
int
head_size
,
const
int
block_size
,
const
float
k_scale
,
const
float
v_scale
)
{
const
int64_t
token_idx
=
blockIdx
.
x
;
const
int64_t
slot_idx
=
slot_mapping
[
token_idx
];
// NOTE: slot_idx can be -1 if the token is padded
...
...
@@ -228,11 +229,20 @@ __global__ void reshape_and_cache_flash_kernel(
const
int64_t
src_value_idx
=
token_idx
*
value_stride
+
i
;
const
int
head_idx
=
i
/
head_size
;
const
int
head_offset
=
i
%
head_size
;
const
int64_t
tgt_value_idx
=
block_idx
*
block_stride
+
block_offset
*
num_heads
*
head_size
+
head_idx
*
head_size
+
head_offset
;
k_cache
[
tgt_value_idx
]
=
key
[
src_key_idx
];
v_cache
[
tgt_value_idx
]
=
value
[
src_value_idx
];
const
int64_t
tgt_key_value_idx
=
block_idx
*
block_stride
+
block_offset
*
num_heads
*
head_size
+
head_idx
*
head_size
+
head_offset
;
scalar_t
tgt_key
=
key
[
src_key_idx
];
scalar_t
tgt_value
=
value
[
src_value_idx
];
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kAuto
)
{
key_cache
[
tgt_key_value_idx
]
=
tgt_key
;
value_cache
[
tgt_key_value_idx
]
=
tgt_value
;
}
else
{
key_cache
[
tgt_key_value_idx
]
=
fp8
::
scaled_convert
<
cache_t
,
scalar_t
,
kv_dt
>
(
tgt_key
,
k_scale
);
value_cache
[
tgt_key_value_idx
]
=
fp8
::
scaled_convert
<
cache_t
,
scalar_t
,
kv_dt
>
(
tgt_value
,
v_scale
);
}
}
}
}
// namespace vllm
...
...
@@ -278,40 +288,45 @@ void reshape_and_cache(
CALL_RESHAPE_AND_CACHE
)
}
// KV_T is the stored data type of kv-cache.
// CACHE_T is the data type of key and value tensors.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \
vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
value_stride, num_heads, head_size, block_size, k_scale, v_scale);
void
reshape_and_cache_flash
(
torch
::
Tensor
&
key
,
// [num_tokens, num_heads, head_size]
torch
::
Tensor
&
value
,
// [num_tokens, num_heads, head_size]
torch
::
Tensor
&
k_cache
,
// [num_blocks, block_size, num_heads, head_size]
torch
::
Tensor
&
v_cache
,
// [num_blocks, block_size, num_heads, head_size]
torch
::
Tensor
&
key
,
// [num_tokens, num_heads, head_size]
torch
::
Tensor
&
value
,
// [num_tokens, num_heads, head_size]
torch
::
Tensor
&
key_cache
,
// [num_blocks, block_size, num_heads, head_size]
torch
::
Tensor
&
value_cache
,
// [num_blocks, block_size, num_heads, head_size]
torch
::
Tensor
&
slot_mapping
,
// [num_tokens]
const
std
::
string
&
kv_cache_dtype
)
{
// FIXME: only support auto datatype, does not support fp8
if
(
kv_cache_dtype
!=
"auto"
)
{
TORCH_CHECK
(
false
,
"Unsupported data type of kv cache: "
,
kv_cache_dtype
);
}
const
std
::
string
&
kv_cache_dtype
,
const
double
k_scale
,
const
double
v_scale
)
{
int
num_tokens
=
key
.
size
(
0
);
int
num_heads
=
key
.
size
(
1
);
int
head_size
=
key
.
size
(
2
);
int
block_size
=
k_cache
.
size
(
1
);
int
block_size
=
k
ey
_cache
.
size
(
1
);
int
key_stride
=
key
.
stride
(
0
);
int
value_stride
=
value
.
stride
(
0
);
int
block_stride
=
k_cache
.
stride
(
0
);
TORCH_CHECK
(
k_cache
.
stride
(
0
)
==
v_cache
.
stride
(
0
));
int
block_stride
=
k
ey
_cache
.
stride
(
0
);
TORCH_CHECK
(
k
ey
_cache
.
stride
(
0
)
==
v
alue
_cache
.
stride
(
0
));
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
num_heads
*
head_size
,
512
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
key
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
key
.
scalar_type
(),
"reshape_and_cache_flash"
,
[
&
]
{
vllm
::
reshape_and_cache_flash_kernel
<
scalar_t
>
<<<
grid
,
block
,
0
,
stream
>>>
(
key
.
data_ptr
<
scalar_t
>
(),
value
.
data_ptr
<
scalar_t
>
(),
k_cache
.
data_ptr
<
scalar_t
>
(),
v_cache
.
data_ptr
<
scalar_t
>
(),
slot_mapping
.
data_ptr
<
int64_t
>
(),
block_stride
,
key_stride
,
value_stride
,
num_heads
,
head_size
,
block_size
);
});
DISPATCH_BY_KV_CACHE_DTYPE
(
key
.
dtype
(),
kv_cache_dtype
,
CALL_RESHAPE_AND_CACHE_FLASH
);
}
namespace
vllm
{
...
...
csrc/torch_bindings.cpp
View file @
0e63494c
...
...
@@ -248,7 +248,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" Tensor! key_cache,"
" Tensor! value_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype) -> ()"
);
" str kv_cache_dtype,"
" float k_scale, float v_scale) -> ()"
);
cache_ops
.
impl
(
"reshape_and_cache_flash"
,
torch
::
kCUDA
,
&
reshape_and_cache_flash
);
...
...
tests/kernels/test_cache.py
View file @
0e63494c
...
...
@@ -215,8 +215,6 @@ def test_reshape_and_cache_flash(
device
:
str
,
kv_cache_dtype
:
str
,
)
->
None
:
if
kv_cache_dtype
==
"fp8"
:
pytest
.
skip
()
random
.
seed
(
seed
)
torch
.
random
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed
(
seed
)
...
...
@@ -248,15 +246,33 @@ def test_reshape_and_cache_flash(
dtype
,
device
=
device
,
)
key_cache
,
value_cache
=
key_caches
[
0
],
value_caches
[
0
]
key_cache
,
value_cache
=
key_caches
[
0
].
contiguous
(
),
value_caches
[
0
].
contiguous
()
del
key_caches
del
value_caches
# Clone the KV caches.
cloned_key_cache
=
key_cache
.
clone
()
cloned_value_cache
=
value_cache
.
clone
()
if
kv_cache_dtype
==
"fp8"
:
cloned_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
cloned_key_cache
,
key_cache
)
cloned_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
cloned_value_cache
,
value_cache
)
else
:
cloned_key_cache
=
key_cache
.
clone
()
cloned_value_cache
=
value_cache
.
clone
()
# Using default kv_scale
k_scale
=
v_scale
=
1.0
# Call the reshape_and_cache kernel.
ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
)
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
)
if
kv_cache_dtype
==
"fp8"
:
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_key_cache
,
key_cache
)
result_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_value_cache
,
value_cache
)
# Run the reference implementation.
block_indicies
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
...
...
@@ -269,8 +285,18 @@ def test_reshape_and_cache_flash(
cloned_key_cache
[
block_idx
,
block_offset
,
:,
:]
=
key
[
i
]
cloned_value_cache
[
block_idx
,
block_offset
,
:,
:]
=
value
[
i
]
assert
torch
.
allclose
(
key_cache
,
cloned_key_cache
)
assert
torch
.
allclose
(
value_cache
,
cloned_value_cache
)
if
kv_cache_dtype
==
"fp8"
:
assert
torch
.
allclose
(
result_key_cache
,
cloned_key_cache
,
atol
=
0.001
,
rtol
=
0.1
)
assert
torch
.
allclose
(
result_value_cache
,
cloned_value_cache
,
atol
=
0.001
,
rtol
=
0.1
)
else
:
assert
torch
.
allclose
(
key_cache
,
cloned_key_cache
)
assert
torch
.
allclose
(
value_cache
,
cloned_value_cache
)
@
pytest
.
mark
.
parametrize
(
"direction"
,
COPYING_DIRECTION
)
...
...
vllm/_custom_ops.py
View file @
0e63494c
...
...
@@ -426,10 +426,13 @@ def reshape_and_cache_flash(
value_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
k_scale
:
float
,
v_scale
:
float
,
)
->
None
:
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
)
kv_cache_dtype
,
k_scale
,
v_scale
)
def
copy_blocks
(
key_caches
:
List
[
torch
.
Tensor
],
...
...
vllm/attention/backends/flash_attn.py
View file @
0e63494c
...
...
@@ -478,6 +478,8 @@ class FlashAttentionImpl(AttentionImpl):
value_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
self
.
kv_cache_dtype
,
k_scale
,
v_scale
,
)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
...
...
vllm/attention/backends/flashinfer.py
View file @
0e63494c
...
...
@@ -489,6 +489,8 @@ class FlashInferImpl(AttentionImpl):
kv_cache
[:,
1
],
attn_metadata
.
slot_mapping
.
flatten
(),
self
.
kv_cache_dtype
,
k_scale
,
v_scale
,
)
query
=
query
.
contiguous
(
...
...
vllm/utils.py
View file @
0e63494c
...
...
@@ -491,7 +491,6 @@ def create_kv_caches_with_random_flash(
seed
:
int
=
0
,
device
:
Optional
[
str
]
=
"cuda"
,
)
->
Tuple
[
List
[
torch
.
Tensor
],
List
[
torch
.
Tensor
]]:
assert
cache_dtype
!=
"fp8"
torch
.
random
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
seed
)
...
...
@@ -507,7 +506,13 @@ def create_kv_caches_with_random_flash(
key_value_cache
=
torch
.
empty
(
size
=
key_value_cache_shape
,
dtype
=
torch_dtype
,
device
=
device
)
key_value_cache
.
uniform_
(
-
scale
,
scale
)
if
cache_dtype
in
[
"auto"
,
"half"
,
"bfloat16"
,
"float"
]:
key_value_cache
.
uniform_
(
-
scale
,
scale
)
elif
cache_dtype
==
'fp8'
:
_generate_random_fp8
(
key_value_cache
,
-
scale
,
scale
)
else
:
raise
ValueError
(
f
"Does not support key cache of type
{
cache_dtype
}
"
)
key_caches
.
append
(
key_value_cache
[:,
0
])
value_caches
.
append
(
key_value_cache
[:,
1
])
return
key_caches
,
value_caches
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment