OpenDAS / text-generation-inference / Commits / 0d794af6

Unverified commit 0d794af6, authored Feb 12, 2024 by OlivierDehaene, committed by GitHub on Feb 12, 2024

feat: experimental support for cuda graphs (#1428)

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

parent 53214633
Showing 17 changed files with 300 additions and 58 deletions (+300, -58)
Dockerfile  +2 -2
Dockerfile_amd  +1 -1
docs/source/basic_tutorials/launcher.md  +8 -0
integration-tests/conftest.py  +4 -1
launcher/src/main.rs  +13 -1
server/Makefile-awq  +4 -2
server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu  +3 -2
server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cuh  +4 -4
server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu  +7 -4
server/exllama_kernels/exllama_kernels/exllama_ext.cpp  +5 -1
server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu  +3 -2
server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu  +7 -4
server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py  +6 -2
server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py  +6 -2
server/text_generation_server/models/flash_causal_lm.py  +117 -14
server/text_generation_server/models/flash_mistral.py  +108 -15
server/text_generation_server/utils/weights.py  +2 -1
Dockerfile

 # Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.71 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.75 AS chef
 WORKDIR /usr/src
 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse

@@ -166,7 +166,7 @@ FROM kernel-builder as megablocks-builder
 RUN pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e

 # Text Generation Inference base image
-FROM nvidia/cuda:12.1.0-base-ubuntu20.04 as base
+FROM nvidia/cuda:12.1.0-base-ubuntu22.04 as base

 # Conda env
 ENV PATH=/opt/conda/bin:$PATH \
Dockerfile_amd

 # Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.71 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.75 AS chef
 WORKDIR /usr/src
 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
docs/source/basic_tutorials/launcher.md

@@ -205,6 +205,14 @@ Options:
           [env: MAX_BATCH_SIZE=]
 ```
+## ENABLE_CUDA_GRAPHS
+```shell
+      --enable-cuda-graphs
+          Enable experimental support for cuda graphs
+
+          [env: ENABLE_CUDA_GRAPHS=]
+
+```
 ## HOSTNAME
 ```shell
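For reference, a minimal sketch (not part of the diff) of how this flag is consumed: the launcher forwards it to each shard as the ENABLE_CUDA_GRAPHS environment variable (see launcher/src/main.rs below), and the Python server gates its graph warmup on the same check added in flash_causal_lm.py.

```python
import os

# Illustration only: mirrors the gate added in flash_causal_lm.py. The launcher
# exports ENABLE_CUDA_GRAPHS="True" to each shard when --enable-cuda-graphs is set.
if os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True":
    print("CUDA graph warmup will run for the supported decode batch sizes")
else:
    print("CUDA graphs disabled; every decode step launches kernels eagerly")
```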
integration-tests/conftest.py

@@ -317,7 +317,10 @@ def launcher(event_loop):
        gpu_count = num_shard if num_shard is not None else 1

-        env = {"LOG_LEVEL": "info,text_generation_router=debug"}
+        env = {
+            "LOG_LEVEL": "info,text_generation_router=debug",
+            "ENABLE_CUDA_GRAPHS": "true",
+        }
        if not use_flash_attention:
            env["USE_FLASH_ATTENTION"] = "false"
launcher/src/main.rs

@@ -284,6 +284,10 @@ struct Args {
     #[clap(long, env)]
     max_batch_size: Option<usize>,

+    /// Enable experimental support for cuda graphs
+    #[clap(long, env)]
+    enable_cuda_graphs: bool,
+
     /// The IP address to listen on
     #[clap(default_value = "0.0.0.0", long, env)]
     hostname: String,

@@ -407,6 +411,7 @@ fn shard_manager(
    disable_custom_kernels: bool,
    watermark_gamma: Option<f32>,
    watermark_delta: Option<f32>,
+   enable_cuda_graphs: bool,
    cuda_memory_fraction: f32,
    rope_scaling: Option<RopeScaling>,
    rope_factor: Option<f32>,

@@ -488,7 +493,7 @@ fn shard_manager(
    envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
    envs.push(("MASTER_ADDR".into(), master_addr.into()));
    envs.push(("MASTER_PORT".into(), master_port.to_string().into()));
-   envs.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));
+   envs.push(("TORCH_NCCL_AVOID_RECORD_STREAMS".into(), "1".into()));

    // CUDA memory fraction
    envs.push((

@@ -538,6 +543,11 @@ fn shard_manager(
        ));
    };

+   // Enable experimental support for cuda graphs
+   if enable_cuda_graphs {
+       envs.push(("ENABLE_CUDA_GRAPHS".into(), "True".into()))
+   }
+
    // If disable_custom_kernels is true, pass it to the shard as an env var
    if disable_custom_kernels {
        envs.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))

@@ -926,6 +936,7 @@ fn spawn_shards(
    let disable_custom_kernels = args.disable_custom_kernels;
    let watermark_gamma = args.watermark_gamma;
    let watermark_delta = args.watermark_delta;
+   let enable_cuda_graphs = args.enable_cuda_graphs;
    let cuda_memory_fraction = args.cuda_memory_fraction;
    let rope_scaling = args.rope_scaling;
    let rope_factor = args.rope_factor;

@@ -947,6 +958,7 @@ fn spawn_shards(
        disable_custom_kernels,
        watermark_gamma,
        watermark_delta,
+       enable_cuda_graphs,
        cuda_memory_fraction,
        rope_scaling,
        rope_factor,
server/Makefile-awq
View file @
0d794af6
awq_commit := f084f40bd996f3cf3a0633c1ad7d9d476c318aaa
# Fork that adds only the correct stream to this kernel in order
# to make cuda graphs work.
awq_commit := bd1dc2d5254345cc76ab71894651fb821275bdd4
awq:
awq:
rm -rf llm-awq
rm -rf llm-awq
git clone https://github.com/
mit-han-lab
/llm-awq
git clone https://github.com/
huggingface
/llm-awq
build-awq: awq
build-awq: awq
cd llm-awq/ && git fetch && git checkout $(awq_commit)
cd llm-awq/ && git fetch && git checkout $(awq_commit)
...
...
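The Makefile comment above points at the one requirement CUDA graphs impose on custom kernels: work must be queued on the stream the capture is recording (what at::cuda::getCurrentCUDAStream() returns on the C++ side), not hard-coded to the default stream. A small hedged illustration in PyTorch terms (toy tensor, assumes a CUDA device; this is not the AWQ fork itself):

```python
import torch

# During torch.cuda.graph() capture, PyTorch switches to a side "capture" stream;
# only kernels launched on that current stream are recorded into the graph.
# Kernels that ignore the current stream would escape or break the capture,
# which is what the forked AWQ/exllama kernels fix by taking the stream explicitly.
static_x = torch.zeros(4, device="cuda")

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_x += 1          # recorded, not executed, during capture

graph.replay()             # now the captured kernel actually runs
torch.cuda.synchronize()
print(static_x)            # tensor([1., 1., 1., 1.], device='cuda:0')
```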
server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu

 #include "q4_matmul.cuh"
 #include "column_remap.cuh"
+#include <ATen/cuda/CUDAContext.h>
 #include "../util.cuh"
 #include "../matrix.cuh"
 #include "../cu_compat.cuh"

@@ -224,8 +225,8 @@ void q4_matmul_recons_cuda
     const int x_height,
     Q4Matrix* w,
     half* out,
-    const cublasHandle_t handle,
-    bool no_zero
+    bool no_zero,
+    const cublasHandle_t handle
 )
 {
     int height = x_height;
server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cuh

@@ -19,8 +19,8 @@ void q4_matmul_cuda
     const int x_height,
     const Q4Matrix* w,
     half* out,
-    bool no_zero = false,
-    cudaStream_t alt_stream = NULL
+    bool no_zero,
+    cudaStream_t alt_stream
 );

 void q4_matmul_recons_cuda

@@ -30,8 +30,8 @@ void q4_matmul_recons_cuda
     const int x_height,
     Q4Matrix* w,
     half* out,
-    const cublasHandle_t handle,
-    bool no_zero = false
+    bool no_zero,
+    const cublasHandle_t handle
 );

 #endif
server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu
View file @
0d794af6
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include <ATen/cuda/CUDAContext.h>
#include "q4_matrix.cuh"
#include "q4_matrix.cuh"
#include <vector>
#include <vector>
#include "../util.cuh"
#include "../util.cuh"
...
@@ -90,7 +91,7 @@ __global__ void make_sequential_kernel
...
@@ -90,7 +91,7 @@ __global__ void make_sequential_kernel
int
w2_row_shift
=
w2_subrow
<<
2
;
int
w2_row_shift
=
w2_subrow
<<
2
;
int
wnew2_row_shift
=
i
<<
2
;
int
wnew2_row_shift
=
i
<<
2
;
uint64_t
src
=
w2
[
w2_row
*
w2_stride
+
w2_column
];
uint64_t
src
=
w2
[
w2_row
*
w2_stride
+
w2_column
];
src
>>=
w2_row_shift
;
src
>>=
w2_row_shift
;
src
&=
0x0000000f0000000f
;
src
&=
0x0000000f0000000f
;
src
<<=
wnew2_row_shift
;
src
<<=
wnew2_row_shift
;
...
@@ -146,7 +147,8 @@ void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx)
...
@@ -146,7 +147,8 @@ void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx)
dim3
threads
(
UNSHUF_BLOCKSIZE_X
,
1
,
1
);
dim3
threads
(
UNSHUF_BLOCKSIZE_X
,
1
,
1
);
dim3
blocks
(
width
/
UNSHUF_BLOCKSIZE_X
/
2
,
height
/
8
,
1
);
dim3
blocks
(
width
/
UNSHUF_BLOCKSIZE_X
/
2
,
height
/
8
,
1
);
make_sequential_kernel
<<<
blocks
,
threads
>>>
(
cuda_qweight
,
cuda_new_qweight
,
cuda_x_map
,
height
/
8
,
width
);
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
make_sequential_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
cuda_qweight
,
cuda_new_qweight
,
cuda_x_map
,
height
/
8
,
width
);
// Replace qweights
// Replace qweights
...
@@ -213,5 +215,6 @@ void Q4Matrix::reconstruct(half* out)
...
@@ -213,5 +215,6 @@ void Q4Matrix::reconstruct(half* out)
1
1
);
);
reconstruct_kernel
<<<
blocks
,
threads
>>>
(
cuda_qweight
,
out
,
cuda_scales
,
cuda_qzeros
,
height
/
8
,
width
,
groupsize
);
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
}
reconstruct_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
cuda_qweight
,
out
,
cuda_scales
,
cuda_qzeros
,
height
/
8
,
width
,
groupsize
);
\ No newline at end of file
}
server/exllama_kernels/exllama_kernels/exllama_ext.cpp

@@ -183,6 +183,7 @@ void q4_matmul
     int x_height = x.size(0);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

     if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
     {
         q4_matmul_cuda

@@ -191,7 +192,9 @@ void q4_matmul
             (half*) x.data_ptr(),
             x_height,
             wm,
-            (half*) out.data_ptr()
+            (half*) out.data_ptr(),
+            false,
+            stream
         );
     }
     else

@@ -203,6 +206,7 @@ void q4_matmul
             x_height,
             wm,
             (half*) out.data_ptr(),
+            false,
             at::cuda::getCurrentCUDABlasHandle()
         );
     }
server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu

@@ -38,6 +38,7 @@ void gemm_half_q_half_cuda_part
     bool mul_r_weights
 )
 {
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     if (!b->is_gptq)
     {
         dim3 blockDim, gridDim;

@@ -50,7 +51,7 @@ void gemm_half_q_half_cuda_part
         fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count, r_weights != NULL, mul_r_weights);

-        kernel<<<gridDim, blockDim>>>
+        kernel<<<gridDim, blockDim, 0, stream>>>
         (
             a,
             b->cuda_q_weight,

@@ -91,7 +92,7 @@ void gemm_half_q_half_cuda_part
         //     print_global_mem(r_weights, 1, 1, 1);
         // DBGI(r_weights_stride);

-        kernel<<<gridDim, blockDim>>>
+        kernel<<<gridDim, blockDim, 0, stream>>>
         (
             a,
             b->cuda_q_weight,
server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu

@@ -168,8 +168,9 @@ QMatrix::QMatrix
     blockDim.y = 1;
     gridDim.x = DIVIDE(width, THREADS_X);
     gridDim.y = 1;

-    shuffle_kernel<<<gridDim, blockDim>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
 }

 QMatrix::~QMatrix()

@@ -475,11 +476,12 @@ void QMatrix::reconstruct(half* out)
     blockDim.x = BLOCK_KN_SIZE;
     blockDim.y = 1;
     gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

     if (!is_gptq)
     {
         gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
-        reconstruct_kernel<<<gridDim, blockDim>>>
+        reconstruct_kernel<<<gridDim, blockDim, 0, stream>>>
         (
             cuda_q_weight,
             cuda_q_perm,

@@ -502,7 +504,7 @@ void QMatrix::reconstruct(half* out)
     else
     {
         gridDim.x = DIVIDE(width, BLOCK_KN_SIZE * 4);
-        reconstruct_gptq_kernel<<<gridDim, blockDim>>>
+        reconstruct_gptq_kernel<<<gridDim, blockDim, 0, stream>>>
         (
             cuda_q_weight,
             cuda_q_perm,

@@ -563,6 +565,7 @@ __global__ void make_sequential_kernel
 bool QMatrix::make_sequential(const uint32_t* cpu_g_idx)
 {
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     uint32_t* cuda_new_qweight = NULL;
     cudaError_t err = cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
     if (err != cudaSuccess) {

@@ -621,7 +624,7 @@ bool QMatrix::make_sequential(const uint32_t* cpu_g_idx)
     gridDim.x = DIVIDE(width, THREADS_X);
     gridDim.y = height / 8;

-    make_sequential_kernel<<<gridDim, blockDim>>>
+    make_sequential_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         cuda_q_weight,
         cuda_new_qweight,
server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py

@@ -425,6 +425,11 @@ class FlashMistralForCausalLM(torch.nn.Module):
             weights=weights,
         )
         self.max_past = config.sliding_window
+        self.max_past_tensor = (
+            torch.tensor(config.sliding_window, device=weights.device)
+            if self.max_past is not None
+            else None
+        )

     def forward(
         self,

@@ -446,8 +451,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            max_s = min(self.max_past, max_s)
-            input_lengths = torch.clamp(input_lengths, max=self.max_past)
+            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)

         hidden_states = self.model(
             input_ids,
server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py

@@ -816,6 +816,11 @@ class FlashMixtralForCausalLM(torch.nn.Module):
             weights=weights,
         )
         self.max_past = config.sliding_window
+        self.max_past_tensor = (
+            torch.tensor(config.sliding_window, device=weights.device)
+            if self.max_past is not None
+            else None
+        )

     def forward(
         self,

@@ -837,8 +842,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            max_s = min(self.max_past, max_s)
-            input_lengths = torch.clamp(input_lengths, max=self.max_past)
+            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)

         hidden_states = self.model(
             input_ids,
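Both modeling files above now build the sliding-window bound once at load time (max_past_tensor) and clamp against that tensor during decode; the removed max_s = min(self.max_past, max_s) adjustment moves up into BaseFlashMistral.forward (see flash_mistral.py below). A toy check, not TGI code, that the tensor clamp matches the previous scalar clamp, presumably so the graphed decode step does not have to materialize a new tensor on every call:

```python
import torch

# Toy illustration: clamping against a preallocated tensor is equivalent to
# clamping against the Python int it was built from; only *when* the bound is
# materialized changes (once, in __init__).
sliding_window = 4096
max_past_tensor = torch.tensor(sliding_window)   # built once, as in __init__

input_lengths = torch.tensor([10, 5000, 4096, 7000])
assert torch.equal(
    torch.clamp(input_lengths, max=sliding_window),
    torch.clamp(input_lengths, max=max_past_tensor),
)
print(torch.clamp(input_lengths, max=max_past_tensor))  # tensor([  10, 4096, 4096, 4096])
```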
server/text_generation_server/models/flash_causal_lm.py
View file @
0d794af6
import
math
import
math
import
os
import
time
import
time
import
itertools
import
itertools
import
torch
import
torch
...
@@ -6,6 +7,7 @@ import torch.distributed
...
@@ -6,6 +7,7 @@ import torch.distributed
import
numpy
as
np
import
numpy
as
np
from
loguru
import
logger
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
opentelemetry
import
trace
from
opentelemetry
import
trace
from
transformers
import
PreTrainedTokenizerBase
from
transformers
import
PreTrainedTokenizerBase
...
@@ -31,6 +33,8 @@ from text_generation_server.utils.dist import MEMORY_FRACTION
...
@@ -31,6 +33,8 @@ from text_generation_server.utils.dist import MEMORY_FRACTION
tracer
=
trace
.
get_tracer
(
__name__
)
tracer
=
trace
.
get_tracer
(
__name__
)
MEM_POOL
=
torch
.
cuda
.
graph_pool_handle
()
@
dataclass
@
dataclass
class
FlashCausalLMBatch
(
Batch
):
class
FlashCausalLMBatch
(
Batch
):
...
@@ -62,7 +66,7 @@ class FlashCausalLMBatch(Batch):
...
@@ -62,7 +66,7 @@ class FlashCausalLMBatch(Batch):
# Set in prefill by the CacheManager
# Set in prefill by the CacheManager
# list of length b of list of length s_i // block_size
# list of length b of list of length s_i // block_size
block_tables
:
Optional
[
List
[
List
[
int
]]]
block_tables
:
Optional
[
List
[
List
[
int
]]]
# tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences
# tensor of size [b, max_
total_
seqlen // block_size] holding the paged attention block tables for all sequences
block_tables_tensor
:
Optional
[
torch
.
Tensor
]
block_tables_tensor
:
Optional
[
torch
.
Tensor
]
# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
slots
:
Optional
[
torch
.
Tensor
]
slots
:
Optional
[
torch
.
Tensor
]
...
@@ -663,6 +667,8 @@ class FlashCausalLM(Model):
...
@@ -663,6 +667,8 @@ class FlashCausalLM(Model):
self
.
num_kv_heads
=
num_kv_heads
self
.
num_kv_heads
=
num_kv_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
cuda_graphs
=
{}
super
(
FlashCausalLM
,
self
).
__init__
(
super
(
FlashCausalLM
,
self
).
__init__
(
model
=
model
,
model
=
model
,
tokenizer
=
tokenizer
,
tokenizer
=
tokenizer
,
...
@@ -678,7 +684,60 @@ class FlashCausalLM(Model):
...
@@ -678,7 +684,60 @@ class FlashCausalLM(Model):
def
batch_type
(
self
)
->
Type
[
FlashCausalLMBatch
]:
def
batch_type
(
self
)
->
Type
[
FlashCausalLMBatch
]:
return
FlashCausalLMBatch
return
FlashCausalLMBatch
def
cuda_graph_warmup
(
self
,
bs
:
int
,
max_s
:
int
,
max_bt
:
int
):
input_ids
=
torch
.
zeros
(
bs
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
position_ids
=
torch
.
zeros
(
bs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
slots
=
torch
.
arange
(
bs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
input_lengths
=
torch
.
ones
(
bs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
*
max_s
block_tables
=
(
torch
.
arange
(
max_bt
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
.
repeat
(
bs
)
.
reshape
((
bs
,
max_bt
))
)
kv_cache
=
get_cache_manager
().
kv_cache
self
.
cuda_graphs
[
bs
]
=
{
"input_ids"
:
input_ids
,
"position_ids"
:
position_ids
,
"kv_cache"
:
kv_cache
,
"block_tables"
:
block_tables
,
"slots"
:
slots
,
"input_lengths"
:
input_lengths
,
}
graph
=
torch
.
cuda
.
CUDAGraph
()
self
.
cuda_graphs
[
bs
][
"graph"
]
=
graph
torch
.
cuda
.
synchronize
()
# Run once outside to warmup
self
.
model
.
forward
(
input_ids
=
input_ids
,
position_ids
=
position_ids
,
cu_seqlen_prefill
=
None
,
kv_cache
=
kv_cache
,
block_tables
=
block_tables
,
slots
=
slots
,
input_lengths
=
input_lengths
,
max_s
=
max_s
,
lm_head_indices
=
None
,
)
torch
.
cuda
.
synchronize
()
with
torch
.
cuda
.
graph
(
graph
,
pool
=
MEM_POOL
):
self
.
cuda_graphs
[
bs
][
"logits"
]
=
self
.
model
.
forward
(
input_ids
=
input_ids
,
position_ids
=
position_ids
,
cu_seqlen_prefill
=
None
,
kv_cache
=
kv_cache
,
block_tables
=
block_tables
,
slots
=
slots
,
input_lengths
=
input_lengths
,
max_s
=
max_s
,
lm_head_indices
=
None
,
)
torch
.
cuda
.
synchronize
()
def
warmup
(
self
,
batch
:
FlashCausalLMBatch
):
def
warmup
(
self
,
batch
:
FlashCausalLMBatch
):
# The warmup batch is the biggest batch we could ever receive
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
try
:
try
:
cache_manager
=
set_cache_manager
(
cache_manager
=
set_cache_manager
(
...
@@ -690,6 +749,8 @@ class FlashCausalLM(Model):
...
@@ -690,6 +749,8 @@ class FlashCausalLM(Model):
self
.
dtype
,
self
.
dtype
,
self
.
device
,
self
.
device
,
)
)
max_bt
=
batch
.
max_blocks
max_s
=
max_bt
*
get_cache_manager
().
block_size
_
,
batch
,
_
=
self
.
generate_token
(
batch
)
_
,
batch
,
_
=
self
.
generate_token
(
batch
)
except
torch
.
cuda
.
OutOfMemoryError
as
e
:
except
torch
.
cuda
.
OutOfMemoryError
as
e
:
raise
RuntimeError
(
raise
RuntimeError
(
...
@@ -713,7 +774,8 @@ class FlashCausalLM(Model):
...
@@ -713,7 +774,8 @@ class FlashCausalLM(Model):
)
)
num_blocks
=
(
num_blocks
=
(
int
(
free_memory
//
total_cache_size
)
# Leave 5% for some wiggle room
int
((
free_memory
*
0.95
)
//
total_cache_size
)
# Add batch.blocks as we allocated it above, so it is included in the peak memory.
# Add batch.blocks as we allocated it above, so it is included in the peak memory.
+
cache_manager
.
num_blocks
+
cache_manager
.
num_blocks
)
)
...
@@ -731,9 +793,19 @@ class FlashCausalLM(Model):
...
@@ -731,9 +793,19 @@ class FlashCausalLM(Model):
self
.
device
,
self
.
device
,
)
)
if
os
.
getenv
(
"ENABLE_CUDA_GRAPHS"
,
"False"
)
==
"True"
:
try
:
logger
.
info
(
"Experimental support for Cuda Graphs is enabled"
)
# Warmup cuda graphs
for
bs
in
[
1
,
2
,
4
]
+
[
8
*
i
for
i
in
range
(
8
)]:
if
self
.
speculate
is
None
or
self
.
speculate
+
1
<=
bs
:
self
.
cuda_graph_warmup
(
bs
,
max_s
,
max_bt
)
except
Exception
:
logger
.
exception
(
f
"Decode cuda graph warmup failed"
)
return
int
(
num_blocks
*
BLOCK_SIZE
)
return
int
(
num_blocks
*
BLOCK_SIZE
)
def
forward
(
self
,
batch
:
FlashCausalLMBatch
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
def
forward
(
self
,
batch
:
FlashCausalLMBatch
)
->
torch
.
Tensor
:
# Model Forward
# Model Forward
if
batch
.
speculative_ids
is
not
None
:
if
batch
.
speculative_ids
is
not
None
:
input_ids
=
batch
.
input_ids
input_ids
=
batch
.
input_ids
...
@@ -785,17 +857,48 @@ class FlashCausalLM(Model):
...
@@ -785,17 +857,48 @@ class FlashCausalLM(Model):
max_s
=
batch
.
max_seqlen
max_s
=
batch
.
max_seqlen
lm_head_indices
=
batch
.
prefill_head_indices
lm_head_indices
=
batch
.
prefill_head_indices
return
self
.
model
.
forward
(
bs
=
input_ids
.
shape
[
0
]
input_ids
=
input_ids
,
padded_bs
=
bs
position_ids
=
position_ids
,
if
bs
==
3
:
cu_seqlen_prefill
=
cu_seqlen_prefill
,
padded_bs
=
4
kv_cache
=
kv_cache
,
elif
3
<
bs
<=
8
:
block_tables
=
block_tables
,
padded_bs
=
8
slots
=
slots
,
elif
bs
>
8
:
input_lengths
=
input_lengths
,
padded_bs
=
(
bs
+
7
)
//
8
*
8
max_s
=
max_s
,
lm_head_indices
=
lm_head_indices
,
# Try to find an associated cuda graph
)
cuda_graph
=
self
.
cuda_graphs
.
get
(
padded_bs
,
None
)
if
cu_seqlen_prefill
is
not
None
or
cuda_graph
is
None
or
batch
.
speculative_ids
is
not
None
:
return
self
.
model
.
forward
(
input_ids
=
input_ids
,
position_ids
=
position_ids
,
cu_seqlen_prefill
=
cu_seqlen_prefill
,
kv_cache
=
kv_cache
,
block_tables
=
block_tables
,
slots
=
slots
,
input_lengths
=
input_lengths
,
max_s
=
max_s
,
lm_head_indices
=
lm_head_indices
,
)
# Copy inputs to the static inputs of the cuda graph
# Static inputs are potentially padded
cuda_graph
[
"input_ids"
][:
input_ids
.
shape
[
0
]]
=
input_ids
cuda_graph
[
"position_ids"
][:
position_ids
.
shape
[
0
]]
=
position_ids
cuda_graph
[
"block_tables"
][
:
block_tables
.
shape
[
0
],
:
block_tables
.
shape
[
1
]
]
=
block_tables
cuda_graph
[
"slots"
].
fill_
(
-
1
)
cuda_graph
[
"slots"
][:
slots
.
shape
[
0
]]
=
slots
cuda_graph
[
"input_lengths"
].
zero_
()
cuda_graph
[
"input_lengths"
][:
input_lengths
.
shape
[
0
]]
=
input_lengths
# Replay the graph
cuda_graph
[
"graph"
].
replay
()
# Slice output to the correct shape
return
cuda_graph
[
"logits"
][:
bs
]
@
tracer
.
start_as_current_span
(
"generate_token"
)
@
tracer
.
start_as_current_span
(
"generate_token"
)
def
generate_token
(
def
generate_token
(
...
...
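The decode path above follows the standard PyTorch capture/replay recipe: allocate static tensors per padded batch size, run the forward once eagerly, capture a second run into a torch.cuda.CUDAGraph that shares one memory pool, then at serving time copy the real (smaller) batch into the static buffers, replay, and slice. A standalone sketch of the same pattern with a toy model (hypothetical names, assumes a CUDA device; not TGI code):

```python
import torch

# One static input buffer and one captured graph per padded batch size,
# all sharing a single graph memory pool.
model = torch.nn.Linear(16, 16).cuda().eval().requires_grad_(False)
MEM_POOL = torch.cuda.graph_pool_handle()
cuda_graphs = {}

def cuda_graph_warmup(bs: int):
    static_in = torch.zeros(bs, 16, device="cuda")
    torch.cuda.synchronize()
    model(static_in)                      # run once eagerly to warm up
    torch.cuda.synchronize()
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=MEM_POOL):
        static_out = model(static_in)     # capture one decode-like forward
    cuda_graphs[bs] = {"graph": graph, "input": static_in, "logits": static_out}

for bs in [1, 2, 4, 8]:
    cuda_graph_warmup(bs)

def forward(x: torch.Tensor) -> torch.Tensor:
    bs = x.shape[0]
    padded_bs = next(b for b in sorted(cuda_graphs) if b >= bs)  # pick a bucket
    entry = cuda_graphs[padded_bs]
    entry["input"].zero_()
    entry["input"][:bs] = x               # copy real inputs into the static buffer
    entry["graph"].replay()               # replay the captured kernels
    return entry["logits"][:bs]           # slice away the padding

print(forward(torch.randn(3, 16, device="cuda")).shape)  # torch.Size([3, 16])
```

Capture happens once per bucket at warmup time; afterwards a decode step is a couple of tensor copies plus graph.replay(), which avoids most of the per-step kernel-launch overhead for small batches.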
server/text_generation_server/models/flash_mistral.py

@@ -35,6 +35,8 @@ tracer = trace.get_tracer(__name__)
 SLIDING_WINDOW: Optional[int] = None
 SLIDING_WINDOW_BLOCKS: Optional[int] = None

+MEM_POOL = torch.cuda.graph_pool_handle()
+

 # Adds windowing logic to FlashCausalLMBatch
 @dataclass

@@ -332,6 +334,8 @@ class BaseFlashMistral(FlashCausalLM):
         model = model_cls(config, weights)

+        self.cuda_graphs = {}
+
         torch.distributed.barrier(group=self.process_group)
         super(BaseFlashMistral, self).__init__(
             model=model,

@@ -350,6 +354,60 @@ class BaseFlashMistral(FlashCausalLM):
     def batch_type(self) -> Type[FlashMistralBatch]:
         return FlashMistralBatch

+    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
+        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
+        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
+        slots = torch.arange(bs, dtype=torch.int32, device=self.device)
+        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
+        block_tables = (
+            torch.arange(max_bt, dtype=torch.int32, device=self.device)
+            .repeat(bs)
+            .reshape((bs, max_bt))
+        )
+        kv_cache = get_cache_manager().kv_cache
+
+        self.cuda_graphs[bs] = {
+            "input_ids": input_ids,
+            "position_ids": position_ids,
+            "kv_cache": kv_cache,
+            "block_tables": block_tables,
+            "slots": slots,
+            "input_lengths": input_lengths,
+        }
+        graph = torch.cuda.CUDAGraph()
+        self.cuda_graphs[bs]["graph"] = graph
+
+        torch.cuda.synchronize()
+        # Run once outside to warmup
+        self.model.forward(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=None,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            input_lengths=input_lengths,
+            max_s=max_s,
+            prefill_cache_indices=None,
+            lm_head_indices=None,
+        )
+        torch.cuda.synchronize()
+
+        with torch.cuda.graph(graph, pool=MEM_POOL):
+            self.cuda_graphs[bs]["logits"] = self.model.forward(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                cu_seqlen_prefill=None,
+                kv_cache=kv_cache,
+                block_tables=block_tables,
+                slots=slots,
+                input_lengths=input_lengths,
+                max_s=max_s,
+                prefill_cache_indices=None,
+                lm_head_indices=None,
+            )
+        torch.cuda.synchronize()
+
     def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
         if batch.speculative_ids is not None:

@@ -401,21 +459,56 @@ class BaseFlashMistral(FlashCausalLM):
         input_lengths = batch.input_lengths_tensor
         max_s = batch.max_seqlen
         lm_head_indices = batch.prefill_head_indices

-        logits = self.model.forward(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            cu_seqlen_prefill=cu_seqlen_prefill,
-            kv_cache=kv_cache,
-            block_tables=block_tables,
-            slots=slots,
-            input_lengths=input_lengths,
-            max_s=max_s,
-            prefill_cache_indices=batch.prefill_cache_indices,
-            lm_head_indices=lm_head_indices,
-        )
-        if batch.prefill_cache_indices is not None:
-            batch.prefill_cache_indices = None
-        return logits
+        if self.model.max_past is not None:
+            max_s = min(self.model.max_past, max_s)
+
+        bs = input_ids.shape[0]
+        padded_bs = bs
+        if bs == 3:
+            padded_bs = 4
+        elif 3 < bs <= 8:
+            padded_bs = 8
+        elif bs > 8:
+            padded_bs = (bs + 7) // 8 * 8
+
+        # Try to find an associated cuda graph
+        cuda_graph = self.cuda_graphs.get(padded_bs, None)
+
+        if cu_seqlen_prefill is not None or cuda_graph is None:
+            logits = self.model.forward(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                cu_seqlen_prefill=cu_seqlen_prefill,
+                kv_cache=kv_cache,
+                block_tables=block_tables,
+                slots=slots,
+                input_lengths=input_lengths,
+                max_s=max_s,
+                prefill_cache_indices=batch.prefill_cache_indices,
+                lm_head_indices=lm_head_indices,
+            )
+            if batch.prefill_cache_indices is not None:
+                batch.prefill_cache_indices = None
+            return logits
+
+        # Copy inputs to the static inputs of the cuda graph
+        # Static inputs are potentially padded
+        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
+        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
+        cuda_graph["block_tables"][
+            : block_tables.shape[0], : block_tables.shape[1]
+        ] = block_tables
+        cuda_graph["slots"].fill_(-1)
+        cuda_graph["slots"][: slots.shape[0]] = slots
+        cuda_graph["input_lengths"].zero_()
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+
+        # Replay the graph
+        cuda_graph["graph"].replay()
+
+        # Slice output to the correct shape
+        return cuda_graph["logits"][:bs]

 class FlashMistral(BaseFlashMistral):
server/text_generation_server/utils/weights.py

@@ -407,8 +407,9 @@ class Weights:
                 data = json.load(f)
                 self.gptq_bits = data["quantization_config"]["bits"]
                 self.gptq_groupsize = data["quantization_config"]["group_size"]
-                self.gptq_desc_act = data["quantization_config"]["desc_act"]
+                # Order is important here, desc_act is missing on some real models
                 self.quant_method = data["quantization_config"]["quant_method"]
+                self.gptq_desc_act = data["quantization_config"]["desc_act"]
             except Exception:
                 filename = "quantize_config.json"
                 try:
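The reordering above matters because these lookups run inside a try/except that falls back to quantize_config.json; setting quant_method before the potentially missing desc_act key means the fallback path still sees it. A hypothetical, stripped-down illustration (not the real Weights class):

```python
# Hypothetical illustration: attributes assigned before the missing key raises
# KeyError remain set when control reaches the except block.
quantization_config = {"bits": 4, "group_size": 128, "quant_method": "gptq"}  # no "desc_act"

class W:  # stand-in for Weights, not the real class
    pass

w = W()
try:
    w.gptq_bits = quantization_config["bits"]
    w.gptq_groupsize = quantization_config["group_size"]
    w.quant_method = quantization_config["quant_method"]  # set first
    w.gptq_desc_act = quantization_config["desc_act"]      # KeyError on some models
except Exception:
    print("falling back, but quant_method is already set:", w.quant_method)
```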