OpenDAS / text-generation-inference / Commits / e563983d

Commit e563983d (unverified), authored Jun 25, 2024 by Wang, Yi; committed via GitHub on Jun 25, 2024. Parent commit: 9e2fdf57

fix cpu and xpu issue (#2116)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Showing 7 changed files with 16 additions and 7 deletions (+16 / -7):
server/text_generation_server/models/flash_causal_lm.py    +4 / -1
server/text_generation_server/models/flash_gpt2.py         +2 / -1
server/text_generation_server/models/flash_llama.py        +2 / -1
server/text_generation_server/models/flash_mistral.py      +2 / -1
server/text_generation_server/models/flash_neox.py         +2 / -1
server/text_generation_server/models/flash_rw.py           +2 / -1
server/text_generation_server/models/flash_santacoder.py   +2 / -1
server/text_generation_server/models/flash_causal_lm.py

@@ -768,7 +768,10 @@ class FlashCausalLM(Model):
         empty_cache()

         element_size = torch.tensor([], dtype=dtype).element_size()
-        x = BLOCK_SIZE // element_size
+        if SYSTEM == "ipex" and device.type == "xpu":
+            x = 1
+        else:
+            x = BLOCK_SIZE // element_size

         if SYSTEM == "ipex" and device == torch.device("cpu"):
             self.kv_cache = [
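For context: the x computed here feeds the KV-cache allocation that follows the hunk, and the change pins it to 1 on the ipex xpu path instead of deriving it from the element size. A minimal standalone sketch of the selection logic, with TGI's globals SYSTEM and BLOCK_SIZE stubbed out as plain arguments for illustration (BLOCK_SIZE of 16 is an assumption here, not shown in this diff):

import torch

BLOCK_SIZE = 16  # assumed paged-attention block size

def kv_split_factor(system: str, device_type: str, dtype: torch.dtype) -> int:
    # element_size() is the per-element byte width of the dtype.
    element_size = torch.tensor([], dtype=dtype).element_size()
    if system == "ipex" and device_type == "xpu":
        # Intel xpu path now uses a factor of 1, per the hunk above.
        return 1
    return BLOCK_SIZE // element_size

print(kv_split_factor("cuda", "cuda", torch.float16))  # 8 (2-byte elements)
print(kv_split_factor("cuda", "cuda", torch.float32))  # 4 (4-byte elements)
print(kv_split_factor("ipex", "xpu", torch.float16))   # 1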
server/text_generation_server/models/flash_gpt2.py

@@ -37,9 +37,10 @@ class FlashGPT2(FlashCausalLM):
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashGPT2 is only available on GPU")
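The remaining five files below receive the same change as flash_gpt2.py above: the dtype default moves into each branch of the device check, so the xpu path keeps float16 while the CPU fallback now defaults to bfloat16, which is generally the better-supported reduced-precision format on x86 CPUs. A minimal sketch of the resulting pattern, with rank passed in as a plain argument (the function name is illustrative, not from the source):

from typing import Optional, Tuple
import torch

def ipex_device_and_dtype(
    rank: int, dtype: Optional[torch.dtype]
) -> Tuple[torch.device, torch.dtype]:
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        # Intel GPU (xpu): default to float16.
        device = torch.device(f"xpu:{rank}")
        dtype = torch.float16 if dtype is None else dtype
    else:
        # CPU fallback: default to bfloat16 instead of float16.
        device = torch.device("cpu")
        dtype = torch.bfloat16 if dtype is None else dtype
    return device, dtype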
server/text_generation_server/models/flash_llama.py

@@ -37,9 +37,10 @@ class FlashLlama(FlashCausalLM):
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")
server/text_generation_server/models/flash_mistral.py

@@ -41,9 +41,10 @@ class BaseFlashMistral(FlashCausalLM):
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashMistral is only available on GPU")
server/text_generation_server/models/flash_neox.py

@@ -36,9 +36,10 @@ class FlashNeoXSharded(FlashCausalLM):
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashNeoX is only available on GPU")
server/text_generation_server/models/flash_rw.py

@@ -37,9 +37,10 @@ class FlashRWSharded(FlashCausalLM):
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashRW is only available on GPU")
server/text_generation_server/models/flash_santacoder.py

@@ -40,9 +40,10 @@ class FlashSantacoderSharded(FlashCausalLM):
         elif SYSTEM == "ipex":
             if hasattr(torch, "xpu") and torch.xpu.is_available():
                 device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
             else:
                 device = torch.device("cpu")
-            dtype = torch.float16 if dtype is None else dtype
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")