OpenDAS / text-generation-inference

Commit 25e8c688
Authored Apr 24, 2024 by huangwb

first runnable TGI changes on DCU platform

Parent: 2d0a7173
Showing 10 changed files with 56 additions and 19 deletions.
launcher/src/main.rs                                          +5   -3
server/Makefile                                               +4   -2
server/exllama_kernels/exllama_kernels/cu_compat.cuh          +2   -2
server/exllamav2_kernels/exllamav2_kernels/cuda/compat.cuh    +2   -2
server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu     +30  -0
server/exllamav2_kernels/setup.py                             +1   -0
server/pyproject.toml                                         +4   -4
server/text_generation_server/models/__init__.py              +4   -4
server/text_generation_server/utils/flash_attn.py             +1   -1
server/text_generation_server/utils/paged_attention.py        +3   -1
launcher/src/main.rs

@@ -1392,9 +1392,11 @@ fn main() -> Result<(), LauncherError> {
             vec![]
         }
         _ => {
-            let cuda_graphs = vec![1, 2, 4, 8, 16, 32];
-            tracing::info!("Using default cuda graphs {cuda_graphs:?}");
-            cuda_graphs
+            // let cuda_graphs = vec![1, 2, 4, 8, 16, 32];
+            // tracing::info!("Using default cuda graphs {cuda_graphs:?}");
+            // cuda_graphs
+            tracing::info!("Currently disable cuda graphs by default,may enable in the future");
+            vec![]
         }
     };
server/Makefile

@@ -19,8 +19,10 @@ gen-server:
 install: gen-server
 	pip install pip --upgrade
-	pip install -r requirements_cuda.txt
-	pip install -e ".[bnb, accelerate, quantize, peft, outlines]"
+	pip install -r requirements_rocm.txt
+	# pip install -e ".[bnb, accelerate, quantize, peft, outlines]"
+	pip install -e ".[accelerate, quantize, peft, outlines]"

 run-dev:
 	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
server/exllama_kernels/exllama_kernels/cu_compat.cuh

@@ -46,10 +46,10 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
 #if defined(__CUDA_ARCH__) || defined(USE_ROCM)
 #if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
-__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
+// __device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
 #if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
-__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
+// __device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
 #endif
 #endif
server/exllamav2_kernels/exllamav2_kernels/cuda/compat.cuh

@@ -44,10 +44,10 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
 #if defined(__CUDA_ARCH__) || defined(USE_ROCM)
 #if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
-__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
+// __device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
 #if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
-__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
+// __device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
 #endif
 #endif
server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu

@@ -23,6 +23,36 @@
 #include "q_gemm_kernel.cuh"
 #include "q_gemm_kernel_gptq.cuh"

+#if defined(USE_ROCM)
+#include <hipblas/hipblas.h>
+__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
+                                                               hipblasOperation_t transA,
+                                                               hipblasOperation_t transB,
+                                                               int m,
+                                                               int n,
+                                                               int k,
+                                                               const half* alpha,
+                                                               const half* AP,
+                                                               int lda,
+                                                               const half* BP,
+                                                               int ldb,
+                                                               const half* beta,
+                                                               half* CP,
+                                                               int ldc) {
+    return hipblasHgemm(handle, transA, transB, m, n, k,
+                        reinterpret_cast<const hipblasHalf*>(alpha),
+                        reinterpret_cast<const hipblasHalf*>(AP), lda,
+                        reinterpret_cast<const hipblasHalf*>(BP), ldb,
+                        reinterpret_cast<const hipblasHalf*>(beta),
+                        reinterpret_cast<hipblasHalf*>(CP), ldc);
+}
+#define hipblasHgemm __compat_hipblasHgemm
+
+// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS.
+#define rocblas_operation_none HIPBLAS_OP_N
+#define rocblas_hgemm __compat_hipblasHgemm
+#endif
+
 void gemm_half_q_half_cuda_part
 (
     const half* a,
server/exllamav2_kernels/setup.py

@@ -6,6 +6,7 @@ extra_cuda_cflags = ["-lineinfo", "-O3"]
 if torch.version.hip:
     extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF"]
+    extra_cuda_cflags += ["-DUSE_ROCM"]

 extra_compile_args = {
     "nvcc": extra_cuda_cflags,
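For orientation, here is a minimal sketch of how the new -DUSE_ROCM define reaches the compiler on a ROCm/DCU build. The flag handling mirrors the hunk above; the setup() call, the source list, and the extension name are illustrative assumptions, not copied from the real setup.py.

# Sketch only: flag assembly matches the diff above; setup() arguments and the
# source list are placeholder assumptions for illustration.
import torch
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

extra_cuda_cflags = ["-lineinfo", "-O3"]
if torch.version.hip:
    # torch.version.hip is set on ROCm/DCU builds of PyTorch.
    extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF"]
    extra_cuda_cflags += ["-DUSE_ROCM"]  # compiles in the hipBLAS compat shims in q_gemm.cu

extra_compile_args = {"nvcc": extra_cuda_cflags}

setup(
    name="exllamav2_kernels",
    ext_modules=[
        CUDAExtension(
            name="exllamav2_kernels",
            sources=["exllamav2_kernels/ext.cpp"],  # placeholder source list
            extra_compile_args=extra_compile_args,
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)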
server/pyproject.toml

@@ -49,10 +49,10 @@ grpcio-tools = "^1.51.1"
 pytest = "^7.3.0"

-[[tool.poetry.source]]
-name = "pytorch-gpu-src"
-url = "https://download.pytorch.org/whl/cu121"
-priority = "explicit"
+# [[tool.poetry.source]]
+# name = "pytorch-gpu-src"
+# url = "https://download.pytorch.org/whl/cu121"
+# priority = "explicit"

 [tool.pytest.ini_options]
 markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]
server/text_generation_server/models/__init__.py

@@ -69,10 +69,10 @@ try:
     from text_generation_server.models.idefics import IDEFICSSharded
     from text_generation_server.models.llava_next import LlavaNext
     from text_generation_server.models.flash_mistral import FlashMistral
-    from text_generation_server.models.flash_mixtral import FlashMixtral
+    # from text_generation_server.models.flash_mixtral import FlashMixtral
     from text_generation_server.models.flash_phi import FlashPhi
     from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
-    from text_generation_server.models.flash_dbrx import FlashDbrx
+    # from text_generation_server.models.flash_dbrx import FlashDbrx
     from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
 except ImportError as e:

@@ -87,8 +87,8 @@ if FLASH_ATTENTION:
     __all__.append(FlashLlama)
     __all__.append(IDEFICSSharded)
     __all__.append(FlashMistral)
-    __all__.append(FlashMixtral)
-    __all__.append(FlashDbrx)
+    # __all__.append(FlashMixtral)
+    # __all__.append(FlashDbrx)
     __all__.append(FlashPhi)
     __all__.append(FlashQwen2)
     __all__.append(FlashStarcoder2)
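Why commenting these two imports out matters: the first hunk sits inside a single try/except wrapping all flash-model imports, so one failing import (Mixtral and Dbrx pull in kernels not yet working on DCU) would otherwise disable every flash model. A minimal sketch of that pattern follows; the module names come from the diff, while the fallback behaviour is an assumption rather than a copy of the real file.

# Sketch of the guarded-import pattern around these hunks; fallback details
# are assumed, not copied from models/__init__.py.
FLASH_ATTENTION = True
try:
    from text_generation_server.models.flash_llama import FlashLlama
    from text_generation_server.models.flash_mistral import FlashMistral
    # On DCU these two imports fail, and a single ImportError here would
    # disable every flash model below, hence they are commented out:
    # from text_generation_server.models.flash_mixtral import FlashMixtral
    # from text_generation_server.models.flash_dbrx import FlashDbrx
except ImportError:
    FLASH_ATTENTION = False

__all__ = []
if FLASH_ATTENTION:
    __all__.append(FlashLlama)
    __all__.append(FlashMistral)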
server/text_generation_server/utils/flash_attn.py

@@ -33,7 +33,7 @@ try:
             "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
             f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
         )
-    if not (is_sm8x or is_sm90):
+    if not (is_sm8x or is_sm90) and IS_CUDA_SYSTEM:
         raise ImportError(
             f"GPU with CUDA capability {major} {minor} is not supported for "
             "Flash Attention V2"
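The effect of this one-line change is that the SM-architecture check only rejects devices on CUDA systems, so a DCU/ROCm device whose reported capability does not map to SM 8.x or 9.0 no longer trips the ImportError. A minimal sketch of the guard, assuming the usual torch capability query and the IS_CUDA_SYSTEM flag from import_utils:

# Sketch of the guard after this change. Variable names mirror the diff; how
# is_sm8x / is_sm90 are derived is an assumption based on common practice.
import torch
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM

major, minor = torch.cuda.get_device_capability()
is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0

# Before: any non-SM8x/SM90 device raised, including ROCm/DCU devices whose
# "capability" numbers do not correspond to NVIDIA SM versions.
# After: the check is scoped to CUDA systems only.
if not (is_sm8x or is_sm90) and IS_CUDA_SYSTEM:
    raise ImportError(
        f"GPU with CUDA capability {major} {minor} is not supported for Flash Attention V2"
    )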
server/text_generation_server/utils/paged_attention.py

 import torch
 from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
+from loguru import logger

 _PARTITION_SIZE = 512

@@ -21,7 +22,8 @@ def reshape_and_cache(
     elif IS_ROCM_SYSTEM:
         from vllm import cache_ops

-        cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
+        # cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
+        cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots.int())
     else:
         raise ValueError("vllm is not supported on your system")
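The substantive change on the ROCm/DCU path is the slots.int() cast: TGI builds the slot-mapping tensor as int64, while the vllm cache kernel used on DCU appears to expect int32 indices (that expectation is inferred from the diff, not verified against the vllm sources). A minimal, self-contained sketch of the cast:

# Sketch only: illustrates the dtype bridge performed by slots.int(); the
# int32 expectation of the DCU vllm kernel is an assumption drawn from the diff.
import torch

slots = torch.tensor([3, 17, 42], dtype=torch.int64)  # TGI's slot-mapping dtype
slots_i32 = slots.int()                               # what gets passed to cache_ops on DCU
assert slots_i32.dtype == torch.int32
assert torch.equal(slots_i32.long(), slots)           # values are preserved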