OpenDAS / text-generation-inference
"tests/vscode:/vscode.git/clone" did not exist on "eb09af5e586d0305bc66c9ef104118afc67b5873"
Commit 25e8c688, authored Apr 24, 2024 by huangwb
first runnable TGI changes on DCU platform
Parent: 2d0a7173
Showing 10 changed files with 56 additions and 19 deletions.
launcher/src/main.rs                                         +5  -3
server/Makefile                                              +4  -2
server/exllama_kernels/exllama_kernels/cu_compat.cuh         +2  -2
server/exllamav2_kernels/exllamav2_kernels/cuda/compat.cuh   +2  -2
server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu    +30 -0
server/exllamav2_kernels/setup.py                            +1  -0
server/pyproject.toml                                        +4  -4
server/text_generation_server/models/__init__.py             +4  -4
server/text_generation_server/utils/flash_attn.py            +1  -1
server/text_generation_server/utils/paged_attention.py       +3  -1
launcher/src/main.rs (+5 -3)

@@ -1392,9 +1392,11 @@ fn main() -> Result<(), LauncherError> {
             vec![]
         }
         _ => {
-            let cuda_graphs = vec![1, 2, 4, 8, 16, 32];
-            tracing::info!("Using default cuda graphs {cuda_graphs:?}");
-            cuda_graphs
+            // let cuda_graphs = vec![1, 2, 4, 8, 16, 32];
+            // tracing::info!("Using default cuda graphs {cuda_graphs:?}");
+            // cuda_graphs
+            tracing::info!("Currently disable cuda graphs by default,may enable in the future");
+            vec![]
         }
     };
server/Makefile (+4 -2)

@@ -19,8 +19,10 @@ gen-server:
 install: gen-server
 	pip install pip --upgrade
-	pip install -r requirements_cuda.txt
-	pip install -e ".[bnb, accelerate, quantize, peft, outlines]"
+	pip install -r requirements_rocm.txt
+	# pip install -e ".[bnb, accelerate, quantize, peft, outlines]"
+	pip install -e ".[accelerate, quantize, peft, outlines]"

 run-dev:
 	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
server/exllama_kernels/exllama_kernels/cu_compat.cuh (+2 -2)

@@ -46,10 +46,10 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
 #if defined(__CUDA_ARCH__) || defined(USE_ROCM)
 #if __CUDA_ARCH__ < 700 || defined(USE_ROCM)

-__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
+// __device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }

 #if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
-__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
+// __device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
 #endif

 #endif
server/exllamav2_kernels/exllamav2_kernels/cuda/compat.cuh (+2 -2)

@@ -44,10 +44,10 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
 #if defined(__CUDA_ARCH__) || defined(USE_ROCM)
 #if __CUDA_ARCH__ < 700 || defined(USE_ROCM)

-__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
+// __device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }

 #if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
-__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
+// __device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
 #endif

 #endif
server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu (+30 -0)

@@ -23,6 +23,36 @@
 #include "q_gemm_kernel.cuh"
 #include "q_gemm_kernel_gptq.cuh"

+#if defined(USE_ROCM)
+#include <hipblas/hipblas.h>
+__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t    handle,
+                                                               hipblasOperation_t transA,
+                                                               hipblasOperation_t transB,
+                                                               int                m,
+                                                               int                n,
+                                                               int                k,
+                                                               const half*        alpha,
+                                                               const half*        AP,
+                                                               int                lda,
+                                                               const half*        BP,
+                                                               int                ldb,
+                                                               const half*        beta,
+                                                               half*              CP,
+                                                               int                ldc) {
+    return hipblasHgemm(handle, transA, transB, m, n, k,
+                        reinterpret_cast<const hipblasHalf *>(alpha),
+                        reinterpret_cast<const hipblasHalf *>(AP), lda,
+                        reinterpret_cast<const hipblasHalf *>(BP), ldb,
+                        reinterpret_cast<const hipblasHalf *>(beta),
+                        reinterpret_cast<hipblasHalf *>(CP), ldc);
+}
+#define hipblasHgemm __compat_hipblasHgemm
+
+// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS.
+#define rocblas_operation_none HIPBLAS_OP_N
+#define rocblas_hgemm __compat_hipblasHgemm
+#endif

 void gemm_half_q_half_cuda_part(const half* a,
server/exllamav2_kernels/setup.py (+1 -0)

@@ -6,6 +6,7 @@ extra_cuda_cflags = ["-lineinfo", "-O3"]

 if torch.version.hip:
     extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF"]
+    extra_cuda_cflags += ["-DUSE_ROCM"]

 extra_compile_args = {
     "nvcc": extra_cuda_cflags,
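A note on the gate itself: the `if torch.version.hip:` branch fires based on how PyTorch was built, not on which GPU is present. As a hedged illustration (not part of the commit), `torch.version.hip` is a version string on ROCm/DCU builds of PyTorch and `None` on CUDA builds:

    # Minimal sketch, assuming a PyTorch install is available; prints which backend it was built for.
    import torch

    if torch.version.hip:
        print("ROCm/HIP build:", torch.version.hip)   # e.g. "5.7.0" on a DCU toolchain
    else:
        print("CUDA build:", torch.version.cuda)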
server/pyproject.toml (+4 -4)

@@ -49,10 +49,10 @@ grpcio-tools = "^1.51.1"
 pytest = "^7.3.0"

-[[tool.poetry.source]]
-name = "pytorch-gpu-src"
-url = "https://download.pytorch.org/whl/cu121"
-priority = "explicit"
+# [[tool.poetry.source]]
+# name = "pytorch-gpu-src"
+# url = "https://download.pytorch.org/whl/cu121"
+# priority = "explicit"

 [tool.pytest.ini_options]
 markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]
server/text_generation_server/models/__init__.py (+4 -4)

@@ -69,10 +69,10 @@ try:
     from text_generation_server.models.idefics import IDEFICSSharded
     from text_generation_server.models.llava_next import LlavaNext
     from text_generation_server.models.flash_mistral import FlashMistral
-    from text_generation_server.models.flash_mixtral import FlashMixtral
+    # from text_generation_server.models.flash_mixtral import FlashMixtral
     from text_generation_server.models.flash_phi import FlashPhi
     from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
-    from text_generation_server.models.flash_dbrx import FlashDbrx
+    # from text_generation_server.models.flash_dbrx import FlashDbrx
     from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
 except ImportError as e:

@@ -87,8 +87,8 @@ if FLASH_ATTENTION:
     __all__.append(FlashLlama)
     __all__.append(IDEFICSSharded)
     __all__.append(FlashMistral)
-    __all__.append(FlashMixtral)
-    __all__.append(FlashDbrx)
+    # __all__.append(FlashMixtral)
+    # __all__.append(FlashDbrx)
     __all__.append(FlashPhi)
     __all__.append(FlashQwen2)
     __all__.append(FlashStarcoder2)
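Why comment the imports out instead of letting them fail: the flash-* model imports live in a single try/except in this module, so one ImportError (for example, from a kernel that does not build on the DCU) would switch off every flash model at once. A hedged sketch of that pattern, with the structure inferred from the hunk above rather than copied from the commit:

    # Illustrative only: one failing import inside the block disables all flash models.
    FLASH_ATTENTION = True
    try:
        from text_generation_server.models.flash_mistral import FlashMistral
        # from text_generation_server.models.flash_mixtral import FlashMixtral  # disabled on DCU
    except ImportError as e:
        print(f"Could not import Flash Attention enabled models: {e}")
        FLASH_ATTENTION = False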
server/text_generation_server/utils/flash_attn.py (+1 -1)

@@ -33,7 +33,7 @@ try:
             "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
             f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
         )
-        if not (is_sm8x or is_sm90):
+        if not (is_sm8x or is_sm90) and IS_CUDA_SYSTEM:
             raise ImportError(
                 f"GPU with CUDA capability {major} {minor} is not supported for "
                 "Flash Attention V2"
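The added `and IS_CUDA_SYSTEM` restricts the compute-capability check to NVIDIA systems; on a DCU/ROCm machine the sm_8x / sm_90 requirement does not apply, so the import no longer raises there. A small hedged illustration of the predicate (values assumed, names taken from the surrounding file):

    # On a DCU/ROCm system the capability flags are False, but so is IS_CUDA_SYSTEM,
    # so the guard above no longer raises ImportError.
    IS_CUDA_SYSTEM, is_sm8x, is_sm90 = False, False, False

    should_raise = not (is_sm8x or is_sm90) and IS_CUDA_SYSTEM
    print(should_raise)  # False -> the flash attention v2 import can proceed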
server/text_generation_server/utils/paged_attention.py (+3 -1)

 import torch

 from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
+from loguru import logger

 _PARTITION_SIZE = 512

@@ -21,7 +22,8 @@ def reshape_and_cache(
     elif IS_ROCM_SYSTEM:
         from vllm import cache_ops

-        cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
+        cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots.int())
+        # cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
     else:
         raise ValueError("vllm is not supported on your system")
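The functional change in the ROCm branch is the `slots.int()` cast. A hedged reading (an assumption, not stated in the commit): the slot-mapping tensor TGI passes here is int64 by default, while the vllm `cache_ops` binding used on the DCU expects int32 indices, hence the explicit down-cast:

    # Hypothetical illustration of the dtype change; the real `slots` tensor comes
    # from TGI's batch bookkeeping, not from this snippet.
    import torch

    slots = torch.tensor([0, 1, 5, 7])        # torch.int64 by default
    print(slots.dtype, slots.int().dtype)     # torch.int64 torch.int32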