norm / vllm · Commits

Commit e070829a (unverified)
Authored May 03, 2023 by Woosuk Kwon; committed by GitHub on May 03, 2023.

Support bfloat16 data type (#54)

Parent: 436e523b

Showing 12 changed files with 455 additions and 53 deletions (+455 / -53).
Changed files:

  cacheflow/master/server.py            +2    -2
  cacheflow/models/utils.py             +1    -0
  csrc/activation_kernels.cu            +3    -1
  csrc/attention/attention_dtypes.h     +4    -0
  csrc/attention/attention_kernels.cu   +6    -2
  csrc/attention/attention_utils.cuh    +1    -1
  csrc/attention/dtype_bfloat16.cuh     +361  -0
  csrc/attention/dtype_float32.cuh      +1    -1
  csrc/cache_kernels.cu                 +55   -43
  csrc/layernorm_kernels.cu             +3    -1
  csrc/pos_encoding_kernels.cu          +3    -1
  setup.py                              +15   -1
cacheflow/master/server.py

@@ -213,8 +213,8 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     parser.add_argument('--use-np-cache', action='store_true',
         help='save a numpy copy of model weights for faster loading')
     parser.add_argument('--use-dummy-weights', action='store_true',
         help='use dummy values for model weights')
-    # NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
-    parser.add_argument('--dtype', type=str, default='half', choices=['half'], help='data type')
+    # NOTE(woosuk): FlashAttention does not support float32.
+    parser.add_argument('--dtype', type=str, default='half', choices=['half', 'bfloat16'], help='data type')
     # Parallel arguments
     parser.add_argument('--use-ray', action='store_true',
         help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
     parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
cacheflow/models/utils.py

@@ -17,6 +17,7 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
     'float': torch.float,
     'float16': torch.float16,
     'float32': torch.float32,
+    'bfloat16': torch.bfloat16,
 }
csrc/activation_kernels.cu

@@ -34,7 +34,9 @@ void silu_and_mul(
   dim3 grid(num_tokens);
   dim3 block(std::min(d, 1024));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
     input.scalar_type(),
     "silu_and_mul_kernel",
     [&] {
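Editor's note on the dispatch change (sketch, not part of the commit): AT_DISPATCH_FLOATING_TYPES_AND_HALF instantiates the kernel lambda only for float, double, and at::Half, so a bfloat16 tensor would previously fall through to the "not implemented" error path. AT_DISPATCH_FLOATING_TYPES_AND2 accepts two extra scalar types, here Half and BFloat16; the same substitution recurs in cache_kernels.cu, layernorm_kernels.cu, and pos_encoding_kernels.cu below. A minimal sketch of the pattern, assuming a CPU tensor and a hypothetical fill_ones helper:

// Sketch only: AT_DISPATCH_FLOATING_TYPES_AND2 compiles the lambda body once per
// supported dtype and binds scalar_t accordingly (float, double, at::Half, at::BFloat16).
#include <ATen/Dispatch.h>
#include <torch/extension.h>

template <typename scalar_t>
void fill_ones(scalar_t* data, int64_t n) {
  for (int64_t i = 0; i < n; ++i) {
    data[i] = scalar_t(1.0f);  // at::Half and at::BFloat16 convert from float
  }
}

void fill_ones_dispatch(torch::Tensor& t) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
    at::ScalarType::Half,
    at::ScalarType::BFloat16,
    t.scalar_type(),
    "fill_ones",
    [&] { fill_ones<scalar_t>(t.data_ptr<scalar_t>(), t.numel()); });
}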
csrc/attention/attention_dtypes.cuh → csrc/attention/attention_dtypes.h (renamed)

@@ -3,3 +3,7 @@
 #include "attention_generic.cuh"
 #include "dtype_float16.cuh"
 #include "dtype_float32.cuh"
+
+#ifdef ENABLE_BF16
+#include "dtype_bfloat16.cuh"
+#endif // ENABLE_BF16
csrc/attention/attention_kernels.cu

 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
-#include "attention_dtypes.cuh"
+#include "attention_dtypes.h"
 #include "attention_utils.cuh"
 #include <algorithm>

@@ -438,9 +438,13 @@ void single_query_cached_kv_attention(
   torch::Tensor& context_lens,    // [num_seqs]
   int block_size,
   int max_context_len) {
-  // TODO(woosuk): Support FP32 and BF16.
+  // TODO(woosuk): Support FP32.
   if (query.dtype() == at::ScalarType::Half) {
     CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t);
+#ifdef ENABLE_BF16
+  } else if (query.dtype() == at::ScalarType::BFloat16) {
+    CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
+#endif
   } else {
     TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
   }
csrc/attention/attention_utils.cuh

 #pragma once
-#include "attention_dtypes.cuh"
+#include "attention_dtypes.h"
 #include <float.h>
 #include <type_traits>
csrc/attention/dtype_bfloat16.cuh (new file, 0 → 100644)

#pragma once

#include "attention_generic.cuh"
#include "dtype_float32.cuh"

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <stdint.h>

namespace cacheflow {

// Define custom BF16 vector data types.
struct bf16_4_t {
  __nv_bfloat162 x;
  __nv_bfloat162 y;
};

struct bf16_8_t {
  __nv_bfloat162 x;
  __nv_bfloat162 y;
  __nv_bfloat162 z;
  __nv_bfloat162 w;
};

// BF16 vector types for Q, K, V.
template<>
struct Vec<__nv_bfloat16, 1> {
  using Type = __nv_bfloat16;
};
template<>
struct Vec<__nv_bfloat16, 2> {
  using Type = __nv_bfloat162;
};
template<>
struct Vec<__nv_bfloat16, 4> {
  using Type = bf16_4_t;
};
template<>
struct Vec<__nv_bfloat16, 8> {
  using Type = bf16_8_t;
};

// FP32 accumulator vector types corresponding to Vec.
template<>
struct FloatVec<__nv_bfloat16> {
  using Type = float;
};
template<>
struct FloatVec<__nv_bfloat162> {
  using Type = float2;
};
template<>
struct FloatVec<bf16_4_t> {
  using Type = Float4_;
};
template<>
struct FloatVec<bf16_8_t> {
  using Type = Float8_;
};

// Utility functions for type conversions.
inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
  return __bfloat1622float2(val);
}

inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
  return __bfloat162bfloat162(val);
}

// Vector addition.
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
  return a + b;
}

inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
  return __hadd2(a, b);
}

inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
  bf16_4_t c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  return c;
}

inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) {
  bf16_8_t c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  c.z = add(a.z, b.z);
  c.w = add(a.w, b.w);
  return c;
}

inline __device__ float2 add(__nv_bfloat162 a, float2 fb) {
  float2 fa = bf1622float2(a);
  return add(fa, fb);
}

inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) {
  Float4_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  return fc;
}

inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
  Float8_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  fc.z = add(a.z, fb.z);
  fc.w = add(a.w, fb.w);
  return fc;
}

// Vector multiplication.
template<>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
  return __hmul(a, b);
}

template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
  return __hmul2(a, b);
}

template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
  return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}

template<>
inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
  bf16_4_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  return c;
}

template<>
inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_4_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  return c;
}

template<>
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
  bf16_8_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
  c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
  return c;
}

template<>
inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_8_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
  c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
  return c;
}

template<>
inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
  float fa = __bfloat162float(a);
  float fb = __bfloat162float(b);
  return fa * fb;
}

template<>
inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
  float2 fa = bf1622float2(a);
  float2 fb = bf1622float2(b);
  return mul<float2, float2, float2>(fa, fb);
}

template<>
inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
  return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}

template<>
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
  Float4_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  return fc;
}

template<>
inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  Float4_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  return fc;
}

template<>
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
  Float8_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
  fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
  return fc;
}

template<>
inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  Float8_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
  fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
  return fc;
}

// Vector fused multiply-add.
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
  return __hfma2(a, b, c);
}

inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) {
  return __hfma2(bf162bf162(a), b, c);
}

inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
  bf16_4_t d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  return d;
}

inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_4_t d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  return d;
}

inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) {
  bf16_8_t d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  d.z = fma(a.z, b.z, c.z);
  d.w = fma(a.w, b.w, c.w);
  return d;
}

inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_8_t d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  d.z = fma(s, b.z, c.z);
  d.w = fma(s, b.w, c.w);
  return d;
}

inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) {
  return __bfloat162float(a) * __bfloat162float(b) + fc;
}

inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) {
  float2 fa = bf1622float2(a);
  float2 fb = bf1622float2(b);
  return fma(fa, fb, fc);
}

inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) {
  return fma(bf162bf162(a), b, fc);
}

inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) {
  Float4_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  return fd;
}

inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) {
  __nv_bfloat162 s = bf162bf162(a);
  Float4_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  return fd;
}

inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) {
  Float8_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  fd.z = fma(a.z, b.z, fc.z);
  fd.w = fma(a.w, b.w, fc.w);
  return fd;
}

inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) {
  __nv_bfloat162 s = bf162bf162(a);
  Float8_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  fd.z = fma(s, b.z, fc.z);
  fd.w = fma(s, b.w, fc.w);
  return fd;
}

// Vector sum.
template<>
inline __device__ float sum(__nv_bfloat16 v) {
  return __bfloat162float(v);
}

template<>
inline __device__ float sum(__nv_bfloat162 v) {
  float2 vf = bf1622float2(v);
  return vf.x + vf.y;
}

template<>
inline __device__ float sum(bf16_4_t v) {
  return sum(v.x) + sum(v.y);
}

template<>
inline __device__ float sum(bf16_8_t v) {
  return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
}

// From float32 to bfloat16.
inline __device__ void from_float(__nv_bfloat16& dst, float src) {
  dst = __float2bfloat16(src);
}

inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
  dst = __float22bfloat162_rn(src);
}

inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
  dst.x = __float22bfloat162_rn(src.x);
  dst.y = __float22bfloat162_rn(src.y);
}

inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
  dst.x = __float22bfloat162_rn(src.x);
  dst.y = __float22bfloat162_rn(src.y);
  dst.z = __float22bfloat162_rn(src.z);
  dst.w = __float22bfloat162_rn(src.w);
}

} // namespace cacheflow
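Editor's sketch (not part of the commit): the header above only supplies the trait specializations (Vec, FloatVec) and BF16 overloads of add/mul/fma/sum; the attention kernel composes them so that BF16 loads are widened into FP32 accumulator vectors before reduction. Roughly, assuming the primary Vec/FloatVec/mul/sum templates from attention_generic.cuh and dtype_float32.cuh, and a hypothetical qk_dot_bf16 helper:

// Sketch only: per-thread BF16 dot product built on the specializations above.
template <int VEC_SIZE>
__device__ float qk_dot_bf16(const __nv_bfloat16* q, const __nv_bfloat16* k) {
  using K_vec = typename cacheflow::Vec<__nv_bfloat16, VEC_SIZE>::Type;  // e.g. bf16_8_t
  using A_vec = typename cacheflow::FloatVec<K_vec>::Type;               // e.g. Float8_

  K_vec q_vec = *reinterpret_cast<const K_vec*>(q);
  K_vec k_vec = *reinterpret_cast<const K_vec*>(k);

  // mul<Acc, A, B> picks the BF16 specialization that widens to FP32 on the fly.
  A_vec prod = cacheflow::mul<A_vec, K_vec, K_vec>(q_vec, k_vec);
  return cacheflow::sum(prod);  // FP32 reduction (specialized in dtype_float32.cuh)
}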
csrc/attention/dtype_float32.cuh

@@ -6,7 +6,7 @@
 namespace cacheflow {

-// Define FP32 vector data types.
+// Define custom FP32 vector data types.
 struct Float4_ {
   float2 x;
   float2 y;
csrc/cache_kernels.cu

@@ -14,14 +14,16 @@ void swap_blocks(
   torch::Device dst_device = dst.device();
   cudaMemcpyKind memcpy_type;
   if (src_device.is_cuda() && dst_device.is_cuda()) {
-    assert(src_device.index() == dst_device.index());
+    TORCH_CHECK(
+      src_device.index() == dst_device.index(),
+      "src and dst must be on the same GPU");
     memcpy_type = cudaMemcpyDeviceToDevice;
   } else if (src_device.is_cuda() && dst_device.is_cpu()) {
     memcpy_type = cudaMemcpyDeviceToHost;
   } else if (src_device.is_cpu() && dst_device.is_cuda()) {
     memcpy_type = cudaMemcpyHostToDevice;
   } else {
-    assert(false);
+    TORCH_CHECK(false, "Invalid device combination");
   }
   void *src_ptr = src.data_ptr();

@@ -29,6 +31,7 @@ void swap_blocks(
   const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  // NOTE(woosuk): This can be slow if the number of blocks is large.
   for (const auto& pair : block_mapping) {
     int64_t src_block_number = pair.first;
     int64_t dst_block_number = pair.second;

@@ -122,7 +125,9 @@ void copy_blocks(
   dim3 grid(num_layers, num_pairs);
   dim3 block(std::min(1024, numel_per_block));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
     key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
       cacheflow::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
         key_cache_ptrs_tensor.data_ptr<int64_t>(),

@@ -176,6 +181,50 @@ __global__ void reshape_and_cache_kernel(
   }
 }

+} // namespace cacheflow
+
+void reshape_and_cache(
+  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
+  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
+  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
+  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
+  torch::Tensor& slot_mapping)  // [num_tokens]
+{
+  int num_tokens = key.size(0);
+  int num_heads = key.size(1);
+  int head_size = key.size(2);
+  int block_size = key_cache.size(3);
+  int x = key_cache.size(4);
+
+  int key_stride = key.stride(0);
+  int value_stride = value.stride(0);
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(num_heads * head_size, 512));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
+    key.scalar_type(),
+    "reshape_and_cache_kernel",
+    [&] {
+      cacheflow::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        key.data_ptr<scalar_t>(),
+        value.data_ptr<scalar_t>(),
+        key_cache.data_ptr<scalar_t>(),
+        value_cache.data_ptr<scalar_t>(),
+        slot_mapping.data_ptr<int>(),
+        key_stride,
+        value_stride,
+        num_heads,
+        head_size,
+        block_size,
+        x);
+    });
+}
+
+namespace cacheflow {
+
 // Grid: (num_blocks, block_size).
 template<typename scalar_t>
 __global__ void gather_cached_kv_kernel(

@@ -296,45 +345,6 @@ __global__ void gather_cached_kv_kernel_optimized(
 } // namespace cacheflow

-void reshape_and_cache(
-  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
-  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
-  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
-  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
-  torch::Tensor& slot_mapping)  // [num_tokens]
-{
-  int num_tokens = key.size(0);
-  int num_heads = key.size(1);
-  int head_size = key.size(2);
-  int block_size = key_cache.size(3);
-  int x = key_cache.size(4);
-
-  int key_stride = key.stride(0);
-  int value_stride = value.stride(0);
-
-  dim3 grid(num_tokens);
-  dim3 block(std::min(num_heads * head_size, 512));
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-    key.scalar_type(),
-    "reshape_and_cache_kernel",
-    [&] {
-      cacheflow::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        key.data_ptr<scalar_t>(),
-        value.data_ptr<scalar_t>(),
-        key_cache.data_ptr<scalar_t>(),
-        value_cache.data_ptr<scalar_t>(),
-        slot_mapping.data_ptr<int>(),
-        key_stride,
-        value_stride,
-        num_heads,
-        head_size,
-        block_size,
-        x);
-    });
-}
-
 void gather_cached_kv(
   torch::Tensor& key,      // [out] [num_tokens, num_heads, head_size]
   torch::Tensor& value,    // [out] [num_tokens, num_heads, head_size]

@@ -354,7 +364,9 @@ void gather_cached_kv(
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
     key.scalar_type(),
     "gather_cached_kv_kernel_optimized",
     [&] {
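Editor's note on the assert → TORCH_CHECK changes above (sketch, not part of the commit): assert is compiled out of release builds that define NDEBUG, so the device checks in swap_blocks would silently vanish; TORCH_CHECK always evaluates its condition and raises a c10::Error with the given message, which surfaces in Python. A minimal sketch with a hypothetical validate_same_device helper:

// Sketch only: TORCH_CHECK reports a catchable, descriptive error instead of aborting.
#include <torch/extension.h>

void validate_same_device(const torch::Tensor& src, const torch::Tensor& dst) {
  TORCH_CHECK(
    src.device() == dst.device(),
    "expected tensors on the same device, got ", src.device(), " and ", dst.device());
}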
csrc/layernorm_kernels.cu

@@ -46,7 +46,9 @@ void rms_norm(
   dim3 grid(num_tokens);
   dim3 block(std::min(hidden_size, 1024));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
     input.scalar_type(),
     "rms_norm_kernel",
     [&] {
csrc/pos_encoding_kernels.cu

@@ -64,7 +64,9 @@ void rotary_embedding_neox(
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
     query.scalar_type(),
     "rotary_embedding_neox",
     [&] {
setup.py

 import setuptools
+import torch
 from torch.utils import cpp_extension

 CXX_FLAGS = ['-g']
 NVCC_FLAGS = ['-O2']

+if not torch.cuda.is_available():
+    raise RuntimeError(
+        f'Cannot find CUDA at CUDA_HOME: {cpp_extension.CUDA_HOME}. '
+        'CUDA must be available in order to build the package.')
+
+# FIXME(woosuk): Consider the case where the machine has multiple GPUs with
+# different compute capabilities.
+compute_capability = torch.cuda.get_device_capability()
+major, minor = compute_capability
+# Enable bfloat16 support if the compute capability is >= 8.0.
+if major >= 8:
+    NVCC_FLAGS.append('-DENABLE_BF16')
+
 ext_modules = []

@@ -23,7 +36,7 @@ attention_extension = cpp_extension.CUDAExtension(
 )
 ext_modules.append(attention_extension)

-# Positional encodings.
+# Positional encoding kernels.
 positional_encoding_extension = cpp_extension.CUDAExtension(
     name='cacheflow.pos_encoding_ops',
     sources=['csrc/pos_encoding.cpp', 'csrc/pos_encoding_kernels.cu'],

@@ -39,6 +52,7 @@ layernorm_extension = cpp_extension.CUDAExtension(
 )
 ext_modules.append(layernorm_extension)

+# Activation kernels.
 activation_extension = cpp_extension.CUDAExtension(
     name='cacheflow.activation_ops',
     sources=['csrc/activation.cpp', 'csrc/activation_kernels.cu'],
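Editor's note (sketch, not part of the commit): setup.py only appends -DENABLE_BF16 when the visible GPU reports compute capability 8.0 or higher, since the __nv_bfloat16 arithmetic intrinsics used in dtype_bfloat16.cuh target that architecture and newer. The flag is consumed by the #ifdef ENABLE_BF16 guards in attention_dtypes.h and attention_kernels.cu, so builds on older GPUs simply omit the BF16 paths. A minimal sketch of the same gate, with a hypothetical bf16_path_available helper:

// Sketch only: compile-time gate driven by the -DENABLE_BF16 flag set in setup.py.
#ifdef ENABLE_BF16
#include <cuda_bf16.h>
#endif

inline bool bf16_path_available() {
#ifdef ENABLE_BF16
  return true;   // built with compute capability >= 8.0
#else
  return false;  // BF16 kernels were not compiled into this build
#endif
}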