Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
bb9e670a
Commit
bb9e670a
authored
Jan 24, 2025
by
xuxzh1
🎱
Browse files
update linear.py
parent
ee9541af
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
25 deletions
+25
-25
server/text_generation_server/layers/linear.py
server/text_generation_server/layers/linear.py
+25
-25
No files found.
server/text_generation_server/layers/linear.py
View file @
bb9e670a
...
...
@@ -3,19 +3,19 @@ from text_generation_server.utils.import_utils import SYSTEM
from
torch.nn
import
functional
as
F
import
os
#
if SYSTEM == "rocm":
#
ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in (
#
"true",
#
"1",
#
)
ROCM_USE_SKINNY_GEMM
=
False
#
if ROCM_USE_SKINNY_GEMM:
#
try:
#
from vllm import _custom_
C
#
except Exception as e:
#
raise ImportError(
#
f"Could not load `vllm._custom_C` for ROCm skinny gemm. Full error: {e}"
#
)
if
SYSTEM
==
"rocm"
:
ROCM_USE_SKINNY_GEMM
=
os
.
getenv
(
"ROCM_USE_SKINNY_GEMM"
,
"True"
).
lower
()
in
(
"true"
,
"1"
,
)
if
ROCM_USE_SKINNY_GEMM
:
try
:
from
vllm
import
_custom_
ops
except
Exception
as
e
:
raise
ImportError
(
f
"Could not load `vllm._custom_C` for ROCm skinny gemm. Full error:
{
e
}
"
)
class
FastLinear
(
torch
.
nn
.
Module
):
...
...
@@ -91,18 +91,18 @@ class FastLinearROCm(torch.nn.Module):
batched
=
True
m
,
n
,
k
=
weight
.
shape
[
0
],
inp_shape
[
0
],
inp_shape
[
1
]
if
m
>
8
and
n
<=
4
:
out
=
torch
.
empty
(
inp_shape
[
0
],
weight
.
shape
[
0
],
dtype
=
inp
.
dtype
,
device
=
weight
.
device
)
_custom_C
.
wvSpltK
(
weight
,
inp
,
out
,
n
,
self
.
cu_count
)
elif
m
%
4
==
0
and
n
==
1
and
k
<=
8192
:
out
=
torch
.
empty
(
inp_shape
[
0
],
weight
.
shape
[
0
],
dtype
=
inp
.
dtype
,
device
=
weight
.
device
)
_custom_C
.
LLMM1
(
weight
,
inp
,
out
,
4
)
else
:
out
=
F
.
linear
(
inp
,
weight
)
#
if m > 8 and n <= 4:
#
out = torch.empty(
#
inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
#
)
#
_custom_C.wvSpltK(weight, inp, out, n, self.cu_count)
#
elif m % 4 == 0 and n == 1 and k <= 8192:
#
out = torch.empty(
#
inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
#
)
#
_custom_C.LLMM1(weight, inp, out, 4)
#
else:
out
=
F
.
linear
(
inp
,
weight
)
if
batched
:
out
.
view
(
*
inp_shape
[:
-
1
],
out
.
shape
[
-
1
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment