OpenDAS / text-generation-inference / Commits / 0d794af6

Unverified commit 0d794af6, authored Feb 12, 2024 by OlivierDehaene, committed by GitHub on Feb 12, 2024.

feat: experimental support for cuda graphs (#1428)

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

Parent: 53214633
Showing 17 changed files with 300 additions and 58 deletions.
Changed files:

  Dockerfile (+2, -2)
  Dockerfile_amd (+1, -1)
  docs/source/basic_tutorials/launcher.md (+8, -0)
  integration-tests/conftest.py (+4, -1)
  launcher/src/main.rs (+13, -1)
  server/Makefile-awq (+4, -2)
  server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu (+3, -2)
  server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cuh (+4, -4)
  server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu (+7, -4)
  server/exllama_kernels/exllama_kernels/exllama_ext.cpp (+5, -1)
  server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu (+3, -2)
  server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu (+7, -4)
  server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py (+6, -2)
  server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py (+6, -2)
  server/text_generation_server/models/flash_causal_lm.py (+117, -14)
  server/text_generation_server/models/flash_mistral.py (+108, -15)
  server/text_generation_server/utils/weights.py (+2, -1)
Dockerfile

 # Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.71 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.75 AS chef
 WORKDIR /usr/src

 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
...
@@ -166,7 +166,7 @@ FROM kernel-builder as megablocks-builder
 RUN pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e

 # Text Generation Inference base image
-FROM nvidia/cuda:12.1.0-base-ubuntu20.04 as base
+FROM nvidia/cuda:12.1.0-base-ubuntu22.04 as base

 # Conda env
 ENV PATH=/opt/conda/bin:$PATH \
...
Dockerfile_amd

 # Rust builder
-FROM lukemathwalker/cargo-chef:latest-rust-1.71 AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.75 AS chef
 WORKDIR /usr/src

 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
...
docs/source/basic_tutorials/launcher.md

...
@@ -205,6 +205,14 @@ Options:
           [env: MAX_BATCH_SIZE=]

 ```
+## ENABLE_CUDA_GRAPHS
+```shell
+      --enable-cuda-graphs
+          Enable experimental support for cuda graphs
+
+          [env: ENABLE_CUDA_GRAPHS=]
+
+```
 ## HOSTNAME
 ```shell
...
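The new option can be turned on either with the CLI switch or with the environment variable it is bound to (`[env: ENABLE_CUDA_GRAPHS=]`), which is also how the integration tests below enable it. A hypothetical launch helper in Python; the model id is a placeholder, not from the docs:

```python
import os
import subprocess

# Setting the env var has the same effect as passing --enable-cuda-graphs,
# because clap binds the flag to ENABLE_CUDA_GRAPHS.
env = {**os.environ, "ENABLE_CUDA_GRAPHS": "true"}

proc = subprocess.Popen(
    ["text-generation-launcher", "--model-id", "mistralai/Mistral-7B-v0.1"],
    env=env,
)
proc.wait()
```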
integration-tests/conftest.py

...
@@ -317,7 +317,10 @@ def launcher(event_loop):
         gpu_count = num_shard if num_shard is not None else 1

-        env = {"LOG_LEVEL": "info,text_generation_router=debug"}
+        env = {
+            "LOG_LEVEL": "info,text_generation_router=debug",
+            "ENABLE_CUDA_GRAPHS": "true",
+        }

         if not use_flash_attention:
             env["USE_FLASH_ATTENTION"] = "false"
...
launcher/src/main.rs

...
@@ -284,6 +284,10 @@ struct Args {
     #[clap(long, env)]
     max_batch_size: Option<usize>,

+    /// Enable experimental support for cuda graphs
+    #[clap(long, env)]
+    enable_cuda_graphs: bool,
+
     /// The IP address to listen on
     #[clap(default_value = "0.0.0.0", long, env)]
     hostname: String,
...
@@ -407,6 +411,7 @@ fn shard_manager(
     disable_custom_kernels: bool,
     watermark_gamma: Option<f32>,
     watermark_delta: Option<f32>,
+    enable_cuda_graphs: bool,
     cuda_memory_fraction: f32,
     rope_scaling: Option<RopeScaling>,
     rope_factor: Option<f32>,
...
@@ -488,7 +493,7 @@ fn shard_manager(
     envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
     envs.push(("MASTER_ADDR".into(), master_addr.into()));
     envs.push(("MASTER_PORT".into(), master_port.to_string().into()));
-    envs.push(("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()));
+    envs.push(("TORCH_NCCL_AVOID_RECORD_STREAMS".into(), "1".into()));

     // CUDA memory fraction
     envs.push((
...
@@ -538,6 +543,11 @@ fn shard_manager(
         ));
     };

+    // Enable experimental support for cuda graphs
+    if enable_cuda_graphs {
+        envs.push(("ENABLE_CUDA_GRAPHS".into(), "True".into()))
+    }
+
     // If disable_custom_kernels is true, pass it to the shard as an env var
     if disable_custom_kernels {
         envs.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
...
@@ -926,6 +936,7 @@ fn spawn_shards(
     let disable_custom_kernels = args.disable_custom_kernels;
     let watermark_gamma = args.watermark_gamma;
     let watermark_delta = args.watermark_delta;
    let enable_cuda_graphs = args.enable_cuda_graphs;
     let cuda_memory_fraction = args.cuda_memory_fraction;
     let rope_scaling = args.rope_scaling;
     let rope_factor = args.rope_factor;
...
@@ -947,6 +958,7 @@ fn spawn_shards(
             disable_custom_kernels,
             watermark_gamma,
             watermark_delta,
+            enable_cuda_graphs,
             cuda_memory_fraction,
             rope_scaling,
             rope_factor,
...
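On the shard side the flag arrives purely as the environment variable pushed above; a minimal sketch of the receiving end, mirroring the check that flash_causal_lm.py adds further down (TORCH_NCCL_AVOID_RECORD_STREAMS is consumed directly by PyTorch's NCCL backend, so the Python server never reads it itself):

```python
import os

# The launcher exports ENABLE_CUDA_GRAPHS="True" when --enable-cuda-graphs is set.
if os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True":
    print("Experimental support for Cuda Graphs is enabled")
```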
server/Makefile-awq

-awq_commit := f084f40bd996f3cf3a0633c1ad7d9d476c318aaa
+# Fork that adds only the correct stream to this kernel in order
+# to make cuda graphs work.
+awq_commit := bd1dc2d5254345cc76ab71894651fb821275bdd4

 awq:
 	rm -rf llm-awq
-	git clone https://github.com/mit-han-lab/llm-awq
+	git clone https://github.com/huggingface/llm-awq

 build-awq: awq
 	cd llm-awq/ && git fetch && git checkout $(awq_commit)
...
server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu

 #include "q4_matmul.cuh"
 #include "column_remap.cuh"
+#include <ATen/cuda/CUDAContext.h>
 #include "../util.cuh"
 #include "../matrix.cuh"
 #include "../cu_compat.cuh"
...
@@ -224,8 +225,8 @@ void q4_matmul_recons_cuda
     const int x_height,
     Q4Matrix* w,
     half* out,
-    const cublasHandle_t handle,
-    bool no_zero
+    bool no_zero,
+    const cublasHandle_t handle
 )
 {
     int height = x_height;
...
server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cuh

...
@@ -19,8 +19,8 @@ void q4_matmul_cuda
     const int x_height,
     const Q4Matrix* w,
     half* out,
-    bool no_zero = false,
-    cudaStream_t alt_stream = NULL
+    bool no_zero,
+    cudaStream_t alt_stream
 );

 void q4_matmul_recons_cuda
...
@@ -30,8 +30,8 @@ void q4_matmul_recons_cuda
     const int x_height,
     Q4Matrix* w,
     half* out,
-    const cublasHandle_t handle,
-    bool no_zero = false
+    bool no_zero,
+    const cublasHandle_t handle
 );

 #endif
server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu
View file @
0d794af6
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include <ATen/cuda/CUDAContext.h>
#include "q4_matrix.cuh"
#include <vector>
#include "../util.cuh"
...
...
@@ -90,7 +91,7 @@ __global__ void make_sequential_kernel
int
w2_row_shift
=
w2_subrow
<<
2
;
int
wnew2_row_shift
=
i
<<
2
;
uint64_t
src
=
w2
[
w2_row
*
w2_stride
+
w2_column
];
uint64_t
src
=
w2
[
w2_row
*
w2_stride
+
w2_column
];
src
>>=
w2_row_shift
;
src
&=
0x0000000f0000000f
;
src
<<=
wnew2_row_shift
;
...
...
@@ -146,7 +147,8 @@ void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx)
dim3
threads
(
UNSHUF_BLOCKSIZE_X
,
1
,
1
);
dim3
blocks
(
width
/
UNSHUF_BLOCKSIZE_X
/
2
,
height
/
8
,
1
);
make_sequential_kernel
<<<
blocks
,
threads
>>>
(
cuda_qweight
,
cuda_new_qweight
,
cuda_x_map
,
height
/
8
,
width
);
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
make_sequential_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
cuda_qweight
,
cuda_new_qweight
,
cuda_x_map
,
height
/
8
,
width
);
// Replace qweights
...
...
@@ -213,5 +215,6 @@ void Q4Matrix::reconstruct(half* out)
1
);
reconstruct_kernel
<<<
blocks
,
threads
>>>
(
cuda_qweight
,
out
,
cuda_scales
,
cuda_qzeros
,
height
/
8
,
width
,
groupsize
);
}
\ No newline at end of file
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
reconstruct_kernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
cuda_qweight
,
out
,
cuda_scales
,
cuda_qzeros
,
height
/
8
,
width
,
groupsize
);
}
server/exllama_kernels/exllama_kernels/exllama_ext.cpp

...
@@ -183,6 +183,7 @@ void q4_matmul
     int x_height = x.size(0);

+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
     {
         q4_matmul_cuda
...
@@ -191,7 +192,9 @@ void q4_matmul
             (half*) x.data_ptr(),
             x_height,
             wm,
-            (half*) out.data_ptr()
+            (half*) out.data_ptr(),
+            false,
+            stream
         );
     }
     else
...
@@ -203,6 +206,7 @@ void q4_matmul
             x_height,
             wm,
             (half*) out.data_ptr(),
+            false,
             at::cuda::getCurrentCUDABlasHandle()
         );
     }
...
server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu

...
@@ -38,6 +38,7 @@ void gemm_half_q_half_cuda_part
     bool mul_r_weights
 )
 {
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     if (!b->is_gptq)
     {
         dim3 blockDim, gridDim;
...
@@ -50,7 +51,7 @@ void gemm_half_q_half_cuda_part
         fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count, r_weights != NULL, mul_r_weights);

-        kernel<<<gridDim, blockDim>>>
+        kernel<<<gridDim, blockDim, 0, stream>>>
         (
             a,
             b->cuda_q_weight,
...
@@ -91,7 +92,7 @@ void gemm_half_q_half_cuda_part
         // print_global_mem(r_weights, 1, 1, 1);
         // DBGI(r_weights_stride);

-        kernel<<<gridDim, blockDim>>>
+        kernel<<<gridDim, blockDim, 0, stream>>>
         (
             a,
             b->cuda_q_weight,
...
server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu

...
@@ -168,8 +168,9 @@ QMatrix::QMatrix
     blockDim.y = 1;
     gridDim.x = DIVIDE(width, THREADS_X);
     gridDim.y = 1;

-    shuffle_kernel<<<gridDim, blockDim>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
 }

 QMatrix::~QMatrix()
...
@@ -475,11 +476,12 @@ void QMatrix::reconstruct(half* out)
     blockDim.x = BLOCK_KN_SIZE;
     blockDim.y = 1;
     gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);

+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     if (!is_gptq)
     {
         gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
-        reconstruct_kernel<<<gridDim, blockDim>>>
+        reconstruct_kernel<<<gridDim, blockDim, 0, stream>>>
         (
             cuda_q_weight,
             cuda_q_perm,
...
@@ -502,7 +504,7 @@ void QMatrix::reconstruct(half* out)
     else
     {
         gridDim.x = DIVIDE(width, BLOCK_KN_SIZE * 4);
-        reconstruct_gptq_kernel<<<gridDim, blockDim>>>
+        reconstruct_gptq_kernel<<<gridDim, blockDim, 0, stream>>>
         (
             cuda_q_weight,
             cuda_q_perm,
...
@@ -563,6 +565,7 @@ __global__ void make_sequential_kernel
 bool QMatrix::make_sequential(const uint32_t* cpu_g_idx)
 {
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     uint32_t* cuda_new_qweight = NULL;
     cudaError_t err = cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
     if (err != cudaSuccess) {
...
@@ -621,7 +624,7 @@ bool QMatrix::make_sequential(const uint32_t* cpu_g_idx)
     gridDim.x = DIVIDE(width, THREADS_X);
     gridDim.y = height / 8;

-    make_sequential_kernel<<<gridDim, blockDim>>>
+    make_sequential_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         cuda_q_weight,
         cuda_new_qweight,
...
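All of the exllama and exllamav2 launches above change from `kernel<<<grid, block>>>` to `kernel<<<grid, block, 0, stream>>>` with `stream = at::cuda::getCurrentCUDAStream()`. Seen from the Python side, the reason is that `torch.cuda.graph` capture runs on a dedicated side stream, so extension kernels hard-coded to the legacy default stream are not part of the capture and typically break it. An illustrative probe, not from the repo (requires a CUDA device):

```python
import torch

assert torch.cuda.is_available()
x = torch.ones(16, device="cuda")

# Warm up outside capture so context and allocator initialisation happen first.
_ = x * 2
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    capture_stream = torch.cuda.current_stream()  # what getCurrentCUDAStream() resolves to
    y = x * 2                                     # this kernel is recorded into the graph

# The capture stream is a side stream, not the legacy default stream.
print(capture_stream.cuda_stream != torch.cuda.default_stream().cuda_stream)  # True
graph.replay()  # re-runs the recorded multiply
```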
server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py

...
@@ -425,6 +425,11 @@ class FlashMistralForCausalLM(torch.nn.Module):
             weights=weights,
         )
         self.max_past = config.sliding_window
+        self.max_past_tensor = (
+            torch.tensor(config.sliding_window, device=weights.device)
+            if self.max_past is not None
+            else None
+        )

     def forward(
         self,
...
@@ -446,8 +451,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            max_s = min(self.max_past, max_s)
-            input_lengths = torch.clamp(input_lengths, max=self.max_past)
+            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)

         hidden_states = self.model(
             input_ids,
...
server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py

...
@@ -816,6 +816,11 @@ class FlashMixtralForCausalLM(torch.nn.Module):
             weights=weights,
         )
         self.max_past = config.sliding_window
+        self.max_past_tensor = (
+            torch.tensor(config.sliding_window, device=weights.device)
+            if self.max_past is not None
+            else None
+        )

     def forward(
         self,
...
@@ -837,8 +842,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            max_s = min(self.max_past, max_s)
-            input_lengths = torch.clamp(input_lengths, max=self.max_past)
+            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)

         hidden_states = self.model(
             input_ids,
...
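Both modeling files above stop clamping against the Python int `self.max_past` and instead clamp against `self.max_past_tensor`, a bound allocated once on the weights' device, presumably so the decode path that gets captured only touches device-resident state. A toy equivalence check with made-up values and dtypes:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
input_lengths = torch.tensor([3, 9, 42], dtype=torch.int32, device=device)

max_past = 8                                                          # old style: Python int
max_past_tensor = torch.tensor(8, dtype=torch.int32, device=device)  # new style: allocated once

a = torch.clamp(input_lengths, max=max_past)
b = torch.clamp(input_lengths, max=max_past_tensor)
assert torch.equal(a, b)  # same result; only where the bound lives differs
```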
server/text_generation_server/models/flash_causal_lm.py
View file @
0d794af6
import
math
import
os
import
time
import
itertools
import
torch
...
...
@@ -6,6 +7,7 @@ import torch.distributed
import
numpy
as
np
from
loguru
import
logger
from
dataclasses
import
dataclass
from
opentelemetry
import
trace
from
transformers
import
PreTrainedTokenizerBase
...
...
@@ -31,6 +33,8 @@ from text_generation_server.utils.dist import MEMORY_FRACTION
tracer
=
trace
.
get_tracer
(
__name__
)
MEM_POOL
=
torch
.
cuda
.
graph_pool_handle
()
@
dataclass
class
FlashCausalLMBatch
(
Batch
):
...
...
@@ -62,7 +66,7 @@ class FlashCausalLMBatch(Batch):
# Set in prefill by the CacheManager
# list of length b of list of length s_i // block_size
block_tables
:
Optional
[
List
[
List
[
int
]]]
# tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences
# tensor of size [b, max_
total_
seqlen // block_size] holding the paged attention block tables for all sequences
block_tables_tensor
:
Optional
[
torch
.
Tensor
]
# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
slots
:
Optional
[
torch
.
Tensor
]
...
...
@@ -663,6 +667,8 @@ class FlashCausalLM(Model):
self
.
num_kv_heads
=
num_kv_heads
self
.
head_size
=
head_size
self
.
cuda_graphs
=
{}
super
(
FlashCausalLM
,
self
).
__init__
(
model
=
model
,
tokenizer
=
tokenizer
,
...
...
@@ -678,7 +684,60 @@ class FlashCausalLM(Model):
def
batch_type
(
self
)
->
Type
[
FlashCausalLMBatch
]:
return
FlashCausalLMBatch
def
cuda_graph_warmup
(
self
,
bs
:
int
,
max_s
:
int
,
max_bt
:
int
):
input_ids
=
torch
.
zeros
(
bs
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
position_ids
=
torch
.
zeros
(
bs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
slots
=
torch
.
arange
(
bs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
input_lengths
=
torch
.
ones
(
bs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
*
max_s
block_tables
=
(
torch
.
arange
(
max_bt
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
.
repeat
(
bs
)
.
reshape
((
bs
,
max_bt
))
)
kv_cache
=
get_cache_manager
().
kv_cache
self
.
cuda_graphs
[
bs
]
=
{
"input_ids"
:
input_ids
,
"position_ids"
:
position_ids
,
"kv_cache"
:
kv_cache
,
"block_tables"
:
block_tables
,
"slots"
:
slots
,
"input_lengths"
:
input_lengths
,
}
graph
=
torch
.
cuda
.
CUDAGraph
()
self
.
cuda_graphs
[
bs
][
"graph"
]
=
graph
torch
.
cuda
.
synchronize
()
# Run once outside to warmup
self
.
model
.
forward
(
input_ids
=
input_ids
,
position_ids
=
position_ids
,
cu_seqlen_prefill
=
None
,
kv_cache
=
kv_cache
,
block_tables
=
block_tables
,
slots
=
slots
,
input_lengths
=
input_lengths
,
max_s
=
max_s
,
lm_head_indices
=
None
,
)
torch
.
cuda
.
synchronize
()
with
torch
.
cuda
.
graph
(
graph
,
pool
=
MEM_POOL
):
self
.
cuda_graphs
[
bs
][
"logits"
]
=
self
.
model
.
forward
(
input_ids
=
input_ids
,
position_ids
=
position_ids
,
cu_seqlen_prefill
=
None
,
kv_cache
=
kv_cache
,
block_tables
=
block_tables
,
slots
=
slots
,
input_lengths
=
input_lengths
,
max_s
=
max_s
,
lm_head_indices
=
None
,
)
torch
.
cuda
.
synchronize
()
def
warmup
(
self
,
batch
:
FlashCausalLMBatch
):
# The warmup batch is the biggest batch we could ever receive
torch
.
cuda
.
empty_cache
()
try
:
cache_manager
=
set_cache_manager
(
...
...
@@ -690,6 +749,8 @@ class FlashCausalLM(Model):
self
.
dtype
,
self
.
device
,
)
max_bt
=
batch
.
max_blocks
max_s
=
max_bt
*
get_cache_manager
().
block_size
_
,
batch
,
_
=
self
.
generate_token
(
batch
)
except
torch
.
cuda
.
OutOfMemoryError
as
e
:
raise
RuntimeError
(
...
...
@@ -713,7 +774,8 @@ class FlashCausalLM(Model):
)
num_blocks
=
(
int
(
free_memory
//
total_cache_size
)
# Leave 5% for some wiggle room
int
((
free_memory
*
0.95
)
//
total_cache_size
)
# Add batch.blocks as we allocated it above, so it is included in the peak memory.
+
cache_manager
.
num_blocks
)
...
...
@@ -731,9 +793,19 @@ class FlashCausalLM(Model):
self
.
device
,
)
if
os
.
getenv
(
"ENABLE_CUDA_GRAPHS"
,
"False"
)
==
"True"
:
try
:
logger
.
info
(
"Experimental support for Cuda Graphs is enabled"
)
# Warmup cuda graphs
for
bs
in
[
1
,
2
,
4
]
+
[
8
*
i
for
i
in
range
(
8
)]:
if
self
.
speculate
is
None
or
self
.
speculate
+
1
<=
bs
:
self
.
cuda_graph_warmup
(
bs
,
max_s
,
max_bt
)
except
Exception
:
logger
.
exception
(
f
"Decode cuda graph warmup failed"
)
return
int
(
num_blocks
*
BLOCK_SIZE
)
def
forward
(
self
,
batch
:
FlashCausalLMBatch
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
def
forward
(
self
,
batch
:
FlashCausalLMBatch
)
->
torch
.
Tensor
:
# Model Forward
if
batch
.
speculative_ids
is
not
None
:
input_ids
=
batch
.
input_ids
...
...
@@ -785,17 +857,48 @@ class FlashCausalLM(Model):
max_s
=
batch
.
max_seqlen
lm_head_indices
=
batch
.
prefill_head_indices
return
self
.
model
.
forward
(
input_ids
=
input_ids
,
position_ids
=
position_ids
,
cu_seqlen_prefill
=
cu_seqlen_prefill
,
kv_cache
=
kv_cache
,
block_tables
=
block_tables
,
slots
=
slots
,
input_lengths
=
input_lengths
,
max_s
=
max_s
,
lm_head_indices
=
lm_head_indices
,
)
bs
=
input_ids
.
shape
[
0
]
padded_bs
=
bs
if
bs
==
3
:
padded_bs
=
4
elif
3
<
bs
<=
8
:
padded_bs
=
8
elif
bs
>
8
:
padded_bs
=
(
bs
+
7
)
//
8
*
8
# Try to find an associated cuda graph
cuda_graph
=
self
.
cuda_graphs
.
get
(
padded_bs
,
None
)
if
cu_seqlen_prefill
is
not
None
or
cuda_graph
is
None
or
batch
.
speculative_ids
is
not
None
:
return
self
.
model
.
forward
(
input_ids
=
input_ids
,
position_ids
=
position_ids
,
cu_seqlen_prefill
=
cu_seqlen_prefill
,
kv_cache
=
kv_cache
,
block_tables
=
block_tables
,
slots
=
slots
,
input_lengths
=
input_lengths
,
max_s
=
max_s
,
lm_head_indices
=
lm_head_indices
,
)
# Copy inputs to the static inputs of the cuda graph
# Static inputs are potentially padded
cuda_graph
[
"input_ids"
][:
input_ids
.
shape
[
0
]]
=
input_ids
cuda_graph
[
"position_ids"
][:
position_ids
.
shape
[
0
]]
=
position_ids
cuda_graph
[
"block_tables"
][
:
block_tables
.
shape
[
0
],
:
block_tables
.
shape
[
1
]
]
=
block_tables
cuda_graph
[
"slots"
].
fill_
(
-
1
)
cuda_graph
[
"slots"
][:
slots
.
shape
[
0
]]
=
slots
cuda_graph
[
"input_lengths"
].
zero_
()
cuda_graph
[
"input_lengths"
][:
input_lengths
.
shape
[
0
]]
=
input_lengths
# Replay the graph
cuda_graph
[
"graph"
].
replay
()
# Slice output to the correct shape
return
cuda_graph
[
"logits"
][:
bs
]
@
tracer
.
start_as_current_span
(
"generate_token"
)
def
generate_token
(
...
...
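The new `cuda_graph_warmup`/`forward` pair above is the core of the feature: record one decode step per bucketed batch size into a `torch.cuda.CUDAGraph` backed by a shared memory pool, then at serving time copy the live batch into the graph's static buffers and replay. A minimal stand-alone sketch of that pattern, with a toy `nn.Linear` standing in for the real model; all names here are illustrative and it requires a CUDA device:

```python
import torch

assert torch.cuda.is_available()
device = torch.device("cuda")

toy_model = torch.nn.Linear(64, 64, device=device)
MEM_POOL = torch.cuda.graph_pool_handle()  # one shared pool, as in the diff
graphs = {}

@torch.no_grad()
def cuda_graph_warmup(bs: int):
    # Static buffers: the captured kernels always read from and write to these.
    static_input = torch.zeros(bs, 64, device=device)

    torch.cuda.synchronize()
    toy_model(static_input)  # run once outside the graph to warm up (lazy cuBLAS init, etc.)
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=MEM_POOL):
        static_logits = toy_model(static_input)

    graphs[bs] = {"input": static_input, "logits": static_logits, "graph": graph}

@torch.no_grad()
def decode_forward(x: torch.Tensor) -> torch.Tensor:
    entry = graphs.get(x.shape[0])
    if entry is None:
        return toy_model(x)       # fall back to the eager path, as prefill does above
    # (the real code also pads the batch up to the captured bucket size; elided here)
    entry["input"].copy_(x)       # copy real inputs into the static buffer
    entry["graph"].replay()       # relaunch the recorded kernels
    return entry["logits"].clone()

cuda_graph_warmup(4)
print(decode_forward(torch.randn(4, 64, device=device)).shape)  # torch.Size([4, 64])
```

Replay only covers the decode path, where shapes per captured batch size are fixed; prefill (and speculative decoding) keeps the eager call, which is why the real `forward` bails out whenever `cu_seqlen_prefill` or `speculative_ids` is set.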
server/text_generation_server/models/flash_mistral.py

...
@@ -35,6 +35,8 @@ tracer = trace.get_tracer(__name__)
 SLIDING_WINDOW: Optional[int] = None
 SLIDING_WINDOW_BLOCKS: Optional[int] = None

+MEM_POOL = torch.cuda.graph_pool_handle()
+

 # Adds windowing logic to FlashCausalLMBatch
 @dataclass
...
@@ -332,6 +334,8 @@ class BaseFlashMistral(FlashCausalLM):
         model = model_cls(config, weights)

+        self.cuda_graphs = {}
+
         torch.distributed.barrier(group=self.process_group)
         super(BaseFlashMistral, self).__init__(
             model=model,
...
@@ -350,6 +354,60 @@ class BaseFlashMistral(FlashCausalLM):
     def batch_type(self) -> Type[FlashMistralBatch]:
         return FlashMistralBatch

+    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
+        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
+        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
+        slots = torch.arange(bs, dtype=torch.int32, device=self.device)
+        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
+        block_tables = (
+            torch.arange(max_bt, dtype=torch.int32, device=self.device)
+            .repeat(bs)
+            .reshape((bs, max_bt))
+        )
+        kv_cache = get_cache_manager().kv_cache
+
+        self.cuda_graphs[bs] = {
+            "input_ids": input_ids,
+            "position_ids": position_ids,
+            "kv_cache": kv_cache,
+            "block_tables": block_tables,
+            "slots": slots,
+            "input_lengths": input_lengths,
+        }
+        graph = torch.cuda.CUDAGraph()
+        self.cuda_graphs[bs]["graph"] = graph
+
+        torch.cuda.synchronize()
+        # Run once outside to warmup
+        self.model.forward(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=None,
+            kv_cache=kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            input_lengths=input_lengths,
+            max_s=max_s,
+            prefill_cache_indices=None,
+            lm_head_indices=None,
+        )
+        torch.cuda.synchronize()
+
+        with torch.cuda.graph(graph, pool=MEM_POOL):
+            self.cuda_graphs[bs]["logits"] = self.model.forward(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                cu_seqlen_prefill=None,
+                kv_cache=kv_cache,
+                block_tables=block_tables,
+                slots=slots,
+                input_lengths=input_lengths,
+                max_s=max_s,
+                prefill_cache_indices=None,
+                lm_head_indices=None,
+            )
+        torch.cuda.synchronize()
+
     def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
         if batch.speculative_ids is not None:
...
@@ -401,21 +459,56 @@ class BaseFlashMistral(FlashCausalLM):
             input_lengths = batch.input_lengths_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices

-        logits = self.model.forward(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            cu_seqlen_prefill=cu_seqlen_prefill,
-            kv_cache=kv_cache,
-            block_tables=block_tables,
-            slots=slots,
-            input_lengths=input_lengths,
-            max_s=max_s,
-            prefill_cache_indices=batch.prefill_cache_indices,
-            lm_head_indices=lm_head_indices,
-        )
-        if batch.prefill_cache_indices is not None:
-            batch.prefill_cache_indices = None
-        return logits
+        if self.model.max_past is not None:
+            max_s = min(self.model.max_past, max_s)
+
+        bs = input_ids.shape[0]
+        padded_bs = bs
+        if bs == 3:
+            padded_bs = 4
+        elif 3 < bs <= 8:
+            padded_bs = 8
+        elif bs > 8:
+            padded_bs = (bs + 7) // 8 * 8
+
+        # Try to find an associated cuda graph
+        cuda_graph = self.cuda_graphs.get(padded_bs, None)
+
+        if cu_seqlen_prefill is not None or cuda_graph is None:
+            logits = self.model.forward(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                cu_seqlen_prefill=cu_seqlen_prefill,
+                kv_cache=kv_cache,
+                block_tables=block_tables,
+                slots=slots,
+                input_lengths=input_lengths,
+                max_s=max_s,
+                prefill_cache_indices=batch.prefill_cache_indices,
+                lm_head_indices=lm_head_indices,
+            )
+            if batch.prefill_cache_indices is not None:
+                batch.prefill_cache_indices = None
+            return logits
+
+        # Copy inputs to the static inputs of the cuda graph
+        # Static inputs are potentially padded
+        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
+        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
+        cuda_graph["block_tables"][
+            : block_tables.shape[0], : block_tables.shape[1]
+        ] = block_tables
+        cuda_graph["slots"].fill_(-1)
+        cuda_graph["slots"][: slots.shape[0]] = slots
+        cuda_graph["input_lengths"].zero_()
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+
+        # Replay the graph
+        cuda_graph["graph"].replay()
+
+        # Slice output to the correct shape
+        return cuda_graph["logits"][:bs]


 class FlashMistral(BaseFlashMistral):
...
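Both `forward` implementations above bucket the incoming decode batch to the nearest captured graph size before looking a graph up. The same mapping, pulled out into a stand-alone helper for clarity (the helper name is mine, not the repo's):

```python
def pad_batch_size(bs: int) -> int:
    # Same branches as the padded_bs computation in the forwards above.
    padded_bs = bs
    if bs == 3:
        padded_bs = 4
    elif 3 < bs <= 8:
        padded_bs = 8
    elif bs > 8:
        padded_bs = (bs + 7) // 8 * 8
    return padded_bs

assert [pad_batch_size(b) for b in (1, 2, 3, 5, 8, 9, 17)] == [1, 2, 4, 8, 8, 16, 24]
```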
server/text_generation_server/utils/weights.py

...
@@ -407,8 +407,9 @@ class Weights:
                 data = json.load(f)

                 self.gptq_bits = data["quantization_config"]["bits"]
                 self.gptq_groupsize = data["quantization_config"]["group_size"]
-                self.gptq_desc_act = data["quantization_config"]["desc_act"]
+                # Order is important here, desc_act is missing on some real models
                 self.quant_method = data["quantization_config"]["quant_method"]
+                self.gptq_desc_act = data["quantization_config"]["desc_act"]
             except Exception:
                 filename = "quantize_config.json"
                 try:
...
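A toy reproduction of why the read order above matters, going by the comment in the diff; the dict and names are illustrative, not the real loader:

```python
config = {"quantization_config": {"bits": 4, "group_size": 128, "quant_method": "awq"}}

loaded = {}
try:
    qc = config["quantization_config"]
    loaded["gptq_bits"] = qc["bits"]
    loaded["gptq_groupsize"] = qc["group_size"]
    # Order is important here, desc_act is missing on some real models
    loaded["quant_method"] = qc["quant_method"]
    loaded["gptq_desc_act"] = qc["desc_act"]  # raises KeyError for this config
except Exception:
    pass  # the real loader falls back to quantize_config.json at this point

print(loaded["quant_method"])  # "awq" was recorded despite the missing desc_act
```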