vllm — commit e070829a (unverified)
Support bfloat16 data type (#54)
Authored May 03, 2023 by Woosuk Kwon; committed via GitHub on May 03, 2023.
Parent: 436e523b

Showing 12 changed files with 455 additions and 53 deletions.
cacheflow/master/server.py              +2    -2
cacheflow/models/utils.py               +1    -0
csrc/activation_kernels.cu              +3    -1
csrc/attention/attention_dtypes.h       +4    -0
csrc/attention/attention_kernels.cu     +6    -2
csrc/attention/attention_utils.cuh      +1    -1
csrc/attention/dtype_bfloat16.cuh       +361  -0
csrc/attention/dtype_float32.cuh        +1    -1
csrc/cache_kernels.cu                   +55   -43
csrc/layernorm_kernels.cu               +3    -1
csrc/pos_encoding_kernels.cu            +3    -1
setup.py                                +15   -1
cacheflow/master/server.py

@@ -213,8 +213,8 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     parser.add_argument('--use-np-cache', action='store_true',
                         help='save a numpy copy of model weights for faster loading')
     parser.add_argument('--use-dummy-weights', action='store_true',
                         help='use dummy values for model weights')
-    # NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
-    parser.add_argument('--dtype', type=str, default='half', choices=['half'], help='data type')
+    # NOTE(woosuk): FlashAttention does not support float32.
+    parser.add_argument('--dtype', type=str, default='half', choices=['half', 'bfloat16'], help='data type')
     # Parallel arguments
     parser.add_argument('--use-ray', action='store_true',
                         help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
     parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1,
                         help='number of pipeline stages')
cacheflow/models/utils.py

@@ -17,6 +17,7 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
     'float': torch.float,
     'float16': torch.float16,
     'float32': torch.float32,
+    'bfloat16': torch.bfloat16,
 }
csrc/activation_kernels.cu

@@ -34,7 +34,9 @@ void silu_and_mul(
   dim3 grid(num_tokens);
   dim3 block(std::min(d, 1024));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
     input.scalar_type(),
     "silu_and_mul_kernel",
     [&] {
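The same dispatch change recurs in csrc/cache_kernels.cu, csrc/layernorm_kernels.cu, and csrc/pos_encoding_kernels.cu below: AT_DISPATCH_FLOATING_TYPES_AND_HALF only generates cases for float, double, and Half, while AT_DISPATCH_FLOATING_TYPES_AND2 also covers BFloat16. The following is a minimal, self-contained sketch of that pattern (not part of the commit; the scale kernel and its names are made up for illustration):

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

// Hypothetical element-wise kernel, templated on the element type chosen by the dispatcher.
template<typename scalar_t>
__global__ void scale_kernel(scalar_t* __restrict__ out,
                             const scalar_t* __restrict__ in,
                             float factor, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // Compute in FP32 and cast back to the storage type (half/bfloat16/float/double).
    out[i] = (scalar_t)((float)in[i] * factor);
  }
}

void scale(torch::Tensor& out, const torch::Tensor& in, float factor) {
  int n = in.numel();
  dim3 grid((n + 255) / 256);
  dim3 block(256);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // Generates a case for float, double, Half, and BFloat16; inside the lambda,
  // scalar_t is bound to the matching C++ type.
  AT_DISPATCH_FLOATING_TYPES_AND2(
    at::ScalarType::Half,
    at::ScalarType::BFloat16,
    in.scalar_type(),
    "scale_kernel",
    [&] {
      scale_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(), in.data_ptr<scalar_t>(), factor, n);
    });
}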
csrc/attention/attention_dtypes.cuh → csrc/attention/attention_dtypes.h

@@ -3,3 +3,7 @@
 #include "attention_generic.cuh"
 #include "dtype_float16.cuh"
 #include "dtype_float32.cuh"
+
+#ifdef ENABLE_BF16
+#include "dtype_bfloat16.cuh"
+#endif // ENABLE_BF16
csrc/attention/attention_kernels.cu

 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
-#include "attention_dtypes.cuh"
+#include "attention_dtypes.h"
 #include "attention_utils.cuh"
 #include <algorithm>

@@ -438,9 +438,13 @@ void single_query_cached_kv_attention(
   torch::Tensor& context_lens,    // [num_seqs]
   int block_size,
   int max_context_len) {
-  // TODO(woosuk): Support FP32 and BF16.
+  // TODO(woosuk): Support FP32.
   if (query.dtype() == at::ScalarType::Half) {
     CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t);
+#ifdef ENABLE_BF16
+  } else if (query.dtype() == at::ScalarType::BFloat16) {
+    CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
+#endif
   } else {
     TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
   }
 }
csrc/attention/attention_utils.cuh

 #pragma once

-#include "attention_dtypes.cuh"
+#include "attention_dtypes.h"

 #include <float.h>
 #include <type_traits>
csrc/attention/dtype_bfloat16.cuh (new file, mode 100644)

#pragma once

#include "attention_generic.cuh"
#include "dtype_float32.cuh"

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <stdint.h>

namespace cacheflow {

// Define custom BF16 vector data types.
struct bf16_4_t {
  __nv_bfloat162 x;
  __nv_bfloat162 y;
};

struct bf16_8_t {
  __nv_bfloat162 x;
  __nv_bfloat162 y;
  __nv_bfloat162 z;
  __nv_bfloat162 w;
};

// BF16 vector types for Q, K, V.
template<>
struct Vec<__nv_bfloat16, 1> {
  using Type = __nv_bfloat16;
};
template<>
struct Vec<__nv_bfloat16, 2> {
  using Type = __nv_bfloat162;
};
template<>
struct Vec<__nv_bfloat16, 4> {
  using Type = bf16_4_t;
};
template<>
struct Vec<__nv_bfloat16, 8> {
  using Type = bf16_8_t;
};

// FP32 accumulator vector types corresponding to Vec.
template<>
struct FloatVec<__nv_bfloat16> {
  using Type = float;
};
template<>
struct FloatVec<__nv_bfloat162> {
  using Type = float2;
};
template<>
struct FloatVec<bf16_4_t> {
  using Type = Float4_;
};
template<>
struct FloatVec<bf16_8_t> {
  using Type = Float8_;
};

// Utility functions for type conversions.
inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
  return __bfloat1622float2(val);
}

inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
  return __bfloat162bfloat162(val);
}

// Vector addition.
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
  return a + b;
}

inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
  return __hadd2(a, b);
}

inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
  bf16_4_t c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  return c;
}

inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) {
  bf16_8_t c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  c.z = add(a.z, b.z);
  c.w = add(a.w, b.w);
  return c;
}

inline __device__ float2 add(__nv_bfloat162 a, float2 fb) {
  float2 fa = bf1622float2(a);
  return add(fa, fb);
}

inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) {
  Float4_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  return fc;
}

inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
  Float8_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  fc.z = add(a.z, fb.z);
  fc.w = add(a.w, fb.w);
  return fc;
}

// Vector multiplication.
template<>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
  return __hmul(a, b);
}

template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
  return __hmul2(a, b);
}

template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
  return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}

template<>
inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
  bf16_4_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  return c;
}

template<>
inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_4_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  return c;
}

template<>
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
  bf16_8_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
  c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
  return c;
}

template<>
inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_8_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
  c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
  return c;
}

template<>
inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
  float fa = __bfloat162float(a);
  float fb = __bfloat162float(b);
  return fa * fb;
}

template<>
inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
  float2 fa = bf1622float2(a);
  float2 fb = bf1622float2(b);
  return mul<float2, float2, float2>(fa, fb);
}

template<>
inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
  return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}

template<>
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
  Float4_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  return fc;
}

template<>
inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  Float4_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  return fc;
}

template<>
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
  Float8_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
  fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
  return fc;
}

template<>
inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  Float8_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
  fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
  return fc;
}

// Vector fused multiply-add.
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
  return __hfma2(a, b, c);
}

inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) {
  return __hfma2(bf162bf162(a), b, c);
}

inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
  bf16_4_t d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  return d;
}

inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_4_t d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  return d;
}

inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) {
  bf16_8_t d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  d.z = fma(a.z, b.z, c.z);
  d.w = fma(a.w, b.w, c.w);
  return d;
}

inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_8_t d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  d.z = fma(s, b.z, c.z);
  d.w = fma(s, b.w, c.w);
  return d;
}

inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) {
  return __bfloat162float(a) * __bfloat162float(b) + fc;
}

inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) {
  float2 fa = bf1622float2(a);
  float2 fb = bf1622float2(b);
  return fma(fa, fb, fc);
}

inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) {
  return fma(bf162bf162(a), b, fc);
}

inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) {
  Float4_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  return fd;
}

inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) {
  __nv_bfloat162 s = bf162bf162(a);
  Float4_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  return fd;
}

inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) {
  Float8_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  fd.z = fma(a.z, b.z, fc.z);
  fd.w = fma(a.w, b.w, fc.w);
  return fd;
}

inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) {
  __nv_bfloat162 s = bf162bf162(a);
  Float8_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  fd.z = fma(s, b.z, fc.z);
  fd.w = fma(s, b.w, fc.w);
  return fd;
}

// Vector sum.
template<>
inline __device__ float sum(__nv_bfloat16 v) {
  return __bfloat162float(v);
}

template<>
inline __device__ float sum(__nv_bfloat162 v) {
  float2 vf = bf1622float2(v);
  return vf.x + vf.y;
}

template<>
inline __device__ float sum(bf16_4_t v) {
  return sum(v.x) + sum(v.y);
}

template<>
inline __device__ float sum(bf16_8_t v) {
  return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
}

// From float32 to bfloat16.
inline __device__ void from_float(__nv_bfloat16& dst, float src) {
  dst = __float2bfloat16(src);
}

inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
  dst = __float22bfloat162_rn(src);
}

inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
  dst.x = __float22bfloat162_rn(src.x);
  dst.y = __float22bfloat162_rn(src.y);
}

inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
  dst.x = __float22bfloat162_rn(src.x);
  dst.y = __float22bfloat162_rn(src.y);
  dst.z = __float22bfloat162_rn(src.z);
  dst.w = __float22bfloat162_rn(src.w);
}

} // namespace cacheflow
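To make the role of these specializations concrete, here is a minimal usage sketch (not part of the commit): the attention kernels load packed bf16 fragments through Vec<__nv_bfloat16, N>::Type, multiply them into the matching FP32 accumulator type from FloatVec, and reduce with sum(). The bf16_dot helper below is hypothetical and assumes it is compiled inside namespace cacheflow with attention_dtypes.h included, ENABLE_BF16 defined, and sum() for the FP32 vector types provided by dtype_float32.cuh as in this commit.

#ifdef ENABLE_BF16
// Hypothetical helper: dot product of VEC_SIZE bfloat16 elements with FP32 accumulation.
// Relies only on the Vec/FloatVec specializations and the mul/sum overloads defined above.
template<int VEC_SIZE>
inline __device__ float bf16_dot(const __nv_bfloat16* q, const __nv_bfloat16* k) {
  using Q_vec = typename Vec<__nv_bfloat16, VEC_SIZE>::Type;  // e.g. bf16_4_t for VEC_SIZE = 4
  using A_vec = typename FloatVec<Q_vec>::Type;               // matching FP32 type, e.g. Float4_
  Q_vec q_vec = *reinterpret_cast<const Q_vec*>(q);           // packed vector load (caller ensures alignment)
  Q_vec k_vec = *reinterpret_cast<const Q_vec*>(k);
  A_vec prod = mul<A_vec, Q_vec, Q_vec>(q_vec, k_vec);        // element-wise products in FP32
  return sum(prod);                                           // reduce to a single float
}
#endif  // ENABLE_BF16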
csrc/attention/dtype_float32.cuh

@@ -6,7 +6,7 @@
 namespace cacheflow {

-// Define FP32 vector data types.
+// Define custom FP32 vector data types.
 struct Float4_ {
   float2 x;
   float2 y;
csrc/cache_kernels.cu

@@ -14,14 +14,16 @@ void swap_blocks(
   torch::Device dst_device = dst.device();
   cudaMemcpyKind memcpy_type;
   if (src_device.is_cuda() && dst_device.is_cuda()) {
-    assert(src_device.index() == dst_device.index());
+    TORCH_CHECK(
+      src_device.index() == dst_device.index(),
+      "src and dst must be on the same GPU");
     memcpy_type = cudaMemcpyDeviceToDevice;
   } else if (src_device.is_cuda() && dst_device.is_cpu()) {
     memcpy_type = cudaMemcpyDeviceToHost;
   } else if (src_device.is_cpu() && dst_device.is_cuda()) {
     memcpy_type = cudaMemcpyHostToDevice;
   } else {
-    assert(false);
+    TORCH_CHECK(false, "Invalid device combination");
   }

   void *src_ptr = src.data_ptr();

@@ -29,6 +31,7 @@ void swap_blocks(
   const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  // NOTE(woosuk): This can be slow if the number of blocks is large.
   for (const auto& pair : block_mapping) {
     int64_t src_block_number = pair.first;
     int64_t dst_block_number = pair.second;

@@ -122,7 +125,9 @@ void copy_blocks(
   dim3 grid(num_layers, num_pairs);
   dim3 block(std::min(1024, numel_per_block));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
     key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
       cacheflow::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
         key_cache_ptrs_tensor.data_ptr<int64_t>(),

@@ -176,6 +181,50 @@ __global__ void reshape_and_cache_kernel(
     }
   }
 }
+
+} // namespace cacheflow
+
+void reshape_and_cache(
+  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
+  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
+  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
+  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
+  torch::Tensor& slot_mapping)  // [num_tokens]
+{
+  int num_tokens = key.size(0);
+  int num_heads = key.size(1);
+  int head_size = key.size(2);
+  int block_size = key_cache.size(3);
+  int x = key_cache.size(4);
+
+  int key_stride = key.stride(0);
+  int value_stride = value.stride(0);
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(num_heads * head_size, 512));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
+    key.scalar_type(),
+    "reshape_and_cache_kernel",
+    [&] {
+      cacheflow::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        key.data_ptr<scalar_t>(),
+        value.data_ptr<scalar_t>(),
+        key_cache.data_ptr<scalar_t>(),
+        value_cache.data_ptr<scalar_t>(),
+        slot_mapping.data_ptr<int>(),
+        key_stride,
+        value_stride,
+        num_heads,
+        head_size,
+        block_size,
+        x);
+    });
+}
+
+namespace cacheflow {
+
 // Grid: (num_blocks, block_size).
 template<typename scalar_t>
 __global__ void gather_cached_kv_kernel(

@@ -296,45 +345,6 @@ __global__ void gather_cached_kv_kernel_optimized(
 } // namespace cacheflow

-void reshape_and_cache(
-  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
-  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
-  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
-  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
-  torch::Tensor& slot_mapping)  // [num_tokens]
-{
-  int num_tokens = key.size(0);
-  int num_heads = key.size(1);
-  int head_size = key.size(2);
-  int block_size = key_cache.size(3);
-  int x = key_cache.size(4);
-
-  int key_stride = key.stride(0);
-  int value_stride = value.stride(0);
-
-  dim3 grid(num_tokens);
-  dim3 block(std::min(num_heads * head_size, 512));
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-    key.scalar_type(),
-    "reshape_and_cache_kernel",
-    [&] {
-      cacheflow::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        key.data_ptr<scalar_t>(),
-        value.data_ptr<scalar_t>(),
-        key_cache.data_ptr<scalar_t>(),
-        value_cache.data_ptr<scalar_t>(),
-        slot_mapping.data_ptr<int>(),
-        key_stride,
-        value_stride,
-        num_heads,
-        head_size,
-        block_size,
-        x);
-    });
-}
-
 void gather_cached_kv(
   torch::Tensor& key,           // [out] [num_tokens, num_heads, head_size]
   torch::Tensor& value,         // [out] [num_tokens, num_heads, head_size]

@@ -354,7 +364,9 @@ void gather_cached_kv(
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
     key.scalar_type(),
     "gather_cached_kv_kernel_optimized",
     [&] {
csrc/layernorm_kernels.cu

@@ -46,7 +46,9 @@ void rms_norm(
   dim3 grid(num_tokens);
   dim3 block(std::min(hidden_size, 1024));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
     input.scalar_type(),
     "rms_norm_kernel",
     [&] {
csrc/pos_encoding_kernels.cu

@@ -64,7 +64,9 @@ void rotary_embedding_neox(
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
    query.scalar_type(),
    "rotary_embedding_neox",
    [&] {
setup.py

 import setuptools
+import torch
 from torch.utils import cpp_extension

 CXX_FLAGS = ['-g']
 NVCC_FLAGS = ['-O2']
+
+if not torch.cuda.is_available():
+    raise RuntimeError(
+        f'Cannot find CUDA at CUDA_HOME: {cpp_extension.CUDA_HOME}. '
+        'CUDA must be available in order to build the package.')
+
+# FIXME(woosuk): Consider the case where the machine has multiple GPUs with
+# different compute capabilities.
+compute_capability = torch.cuda.get_device_capability()
+major, minor = compute_capability
+# Enable bfloat16 support if the compute capability is >= 8.0.
+if major >= 8:
+    NVCC_FLAGS.append('-DENABLE_BF16')

 ext_modules = []

@@ -23,7 +36,7 @@ attention_extension = cpp_extension.CUDAExtension(
 )
 ext_modules.append(attention_extension)

-# Positional encodings.
+# Positional encoding kernels.
 positional_encoding_extension = cpp_extension.CUDAExtension(
     name='cacheflow.pos_encoding_ops',
     sources=['csrc/pos_encoding.cpp', 'csrc/pos_encoding_kernels.cu'],

@@ -39,6 +52,7 @@ layernorm_extension = cpp_extension.CUDAExtension(
 )
 ext_modules.append(layernorm_extension)

+# Activation kernels.
 activation_extension = cpp_extension.CUDAExtension(
     name='cacheflow.activation_ops',
     sources=['csrc/activation.cpp', 'csrc/activation_kernels.cu'],
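For context (not part of the commit): the -DENABLE_BF16 flag defined above is what gates the bfloat16 paths in the CUDA sources, namely the dtype_bfloat16.cuh include and the BFloat16 dispatch branch in the attention kernels, and setup.py only adds it when the detected compute capability is at least 8.0. A minimal sketch of device code guarded the same way follows; the axpy_bf16 kernel and its names are made up for illustration.

#include <cuda_bf16.h>

#ifdef ENABLE_BF16
// y[i] = alpha * x[i] + y[i], with bfloat16 storage and FP32 accumulation,
// mirroring the FP32 accumulator types used by the attention dtypes above.
__global__ void axpy_bf16(__nv_bfloat16* __restrict__ y,
                          const __nv_bfloat16* __restrict__ x,
                          float alpha, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // Convert to FP32, accumulate, and convert back to bfloat16.
    float acc = alpha * __bfloat162float(x[i]) + __bfloat162float(y[i]);
    y[i] = __float2bfloat16(acc);
  }
}
#endif  // ENABLE_BF16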