OpenDAS / text-generation-inference · Commits

Commit efd602c8, authored Oct 29, 2024 by xuxzh1 🎱
Commit message: last
Parent commit: f1b779fc

Changes: 214 files; showing 20 changed files with 3213 additions and 84 deletions (+3213 / -84).
Changed files:

  server/requirements_cuda.txt (+10 / -10)
  server/requirements_intel.txt (+48 / -0)
  server/requirements_rocm.txt (+10 / -10)
  server/tests/models/test_bloom.py (+1 / -0)
  server/tests/models/test_causal_lm.py (+1 / -0)
  server/tests/models/test_model.py (+6 / -1)
  server/tests/models/test_santacoder.py (+8 / -0)
  server/tests/models/test_seq2seq_lm.py (+1 / -0)
  server/tests/utils/test_layers.py (+1 / -1)
  server/tests/utils/test_weights.py (+1152 / -0)
  server/text_generation_server/adapters/__init__.py (+13 / -0)
  server/text_generation_server/adapters/config.py (+44 / -0)
  server/text_generation_server/adapters/lora.py (+482 / -0)
  server/text_generation_server/adapters/weights.py (+158 / -0)
  server/text_generation_server/cli.py (+91 / -62)
  server/text_generation_server/layers/__init__.py (+20 / -0)
  server/text_generation_server/layers/attention/__init__.py (+15 / -0)
  server/text_generation_server/layers/attention/common.py (+44 / -0)
  server/text_generation_server/layers/attention/cuda.py (+292 / -0)
  server/text_generation_server/layers/attention/flash_attn_triton.py (+816 / -0)
server/requirements_cuda.txt

@@ -6,14 +6,14 @@ colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_p
 deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
 einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
-fsspec==2024.3.1 ; python_version >= "3.9" and python_version < "3.13"
+fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
 googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
 grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
 grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
 grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
-grpcio==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13"
 hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
-huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
+huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13"
 idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
 loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
 numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
@@ -32,17 +32,17 @@ prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
 protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
 py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
 pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
-regex==2024.4.28 ; python_version >= "3.9" and python_version < "3.13"
+regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
-requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
+requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13"
 safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
-scipy==1.13.0 ; python_version >= "3.9" and python_version < "3.13"
+scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
 sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
-setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13"
+setuptools==70.0.0 ; python_version >= "3.9" and python_version < "3.13"
 tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
-tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
+tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.40.1 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13"
 typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13"
+typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13"
 urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
 win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
 wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
server/requirements_intel.txt (new file, mode 100644)

backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
server/requirements_rocm.txt

@@ -6,14 +6,14 @@ colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_p
 deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
 einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
-fsspec==2024.3.1 ; python_version >= "3.9" and python_version < "3.13"
+fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
 googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
 grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
 grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
 grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
-grpcio==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13"
 hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
-huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
+huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13"
 idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
 loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
 numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
@@ -32,17 +32,17 @@ prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
 protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
 py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
 pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
-regex==2024.4.28 ; python_version >= "3.9" and python_version < "3.13"
+regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
-requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
+requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13"
 safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
-scipy==1.13.0 ; python_version >= "3.9" and python_version < "3.13"
+scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
 sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
-setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13"
+setuptools==70.0.0 ; python_version >= "3.9" and python_version < "3.13"
 tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
-tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
+tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.40.1 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13"
 typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13"
+typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13"
 urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
 win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
 wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
server/tests/models/test_bloom.py

@@ -29,6 +29,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
+        input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
         prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
server/tests/models/test_causal_lm.py

@@ -25,6 +25,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
+        input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
         prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
server/tests/models/test_model.py

@@ -17,7 +17,12 @@ def get_test_model():
     tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b")
     model = TestModel(
-        torch.nn.Linear(1, 1), tokenizer, False, torch.float32, torch.device("cpu")
+        "test_model_id",
+        torch.nn.Linear(1, 1),
+        tokenizer,
+        False,
+        torch.float32,
+        torch.device("cpu"),
     )
     return model
server/tests/models/test_santacoder.py

@@ -15,6 +15,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="def",
+        input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="def")]),
         prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
@@ -32,6 +33,13 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
+        input_chunks=generate_pb2.Input(
+            chunks=[
+                generate_pb2.InputChunk(
+                    text="<fim-prefix>def<fim-suffix>world<fim-middle>"
+                )
+            ]
+        ),
         prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
server/tests/models/test_seq2seq_lm.py

@@ -28,6 +28,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
+        input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
         prefill_logprobs=True,
         truncate=100,
         parameters=default_pb_parameters,
server/tests/utils/test_layers.py

 import torch
-from text_generation_server.utils.layers import (
+from text_generation_server.layers import (
     TensorParallelEmbedding,
 )
server/tests/utils/test_weights.py (new file, mode 100644)

import pytest
import torch
from text_generation_server.utils.weights import Weights
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.layers.exl2 import Exl2Weight
from text_generation_server.layers.marlin import MarlinWeight
from types import SimpleNamespace
from typing import List, Optional, Dict, Union
from pathlib import Path

dummy_file_system = {
    "test_weights": {
        "layer.0.weight": torch.tensor([[1, 2], [3, 4]], dtype=torch.float32),
    },
    "test_weights_2": {
        "layer.1337.weight": torch.tensor(
            [[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.float32
        ),
    },
    "test_get_weights_col_packed": {
        "weight.weight": torch.tensor(
            [[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.float32
        ),
    },
    "test_get_multi_weights_col": {
        # The key is duplicated in the original source; the second entry wins.
        "weight.weight": torch.tensor(
            [[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.float32
        ),
        "weight.weight": torch.tensor(
            [[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.float32
        ),
    },
    "test_get_multi_weights_row": {
        "weight.weight": torch.tensor(
            [[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.float32
        ),
    },
    "test_get_weights_col_gptq": {
        "weight.qweight": torch.tensor(
            [[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.float32
        ),
        "weight.g_idx": torch.tensor([0, 1, 0, 1], dtype=torch.int32),
        "weight.qzeros": torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
        "weight.scales": torch.tensor(
            [[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16
        ),
        "gptq_bits": torch.tensor([8], dtype=torch.float32),
        "gptq_groupsize": torch.tensor([2], dtype=torch.float32),
    },
    "test_get_weights_col_marlin": {
        "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
        "weight.s": torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
    },
    "test_get_multi_weights_row_gptq": {
        "weight.qweight": torch.tensor(
            [[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32
        ),
        "weight.g_idx": torch.tensor([0, 1, 0, 1], dtype=torch.int32),
        "weight.qzeros": torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
        "weight.scales": torch.tensor(
            [[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16
        ),
        "gptq_bits": torch.tensor([8], dtype=torch.float32),
        "gptq_groupsize": torch.tensor([2], dtype=torch.float32),
    },
    "test_get_multi_weights_col_gptq": {
        "weight.qweight": torch.tensor(
            [[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32
        ),
        "weight.g_idx": torch.tensor([0, 1, 0, 1], dtype=torch.int32),
        "weight.qzeros": torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
        "weight.scales": torch.tensor(
            [[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16
        ),
        "gptq_bits": torch.tensor([8], dtype=torch.float32),
        "gptq_groupsize": torch.tensor([2], dtype=torch.float32),
    },
    "test_get_weights_col_packed_gptq": {
        "weight.qweight": torch.tensor(
            [[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32
        ),
        "weight.g_idx": torch.tensor([0, 1, 0, 1], dtype=torch.int32),
        "weight.qzeros": torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
        "weight.scales": torch.tensor(
            [[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16
        ),
        "gptq_bits": torch.tensor([8], dtype=torch.float32),
        "gptq_groupsize": torch.tensor([2], dtype=torch.float32),
    },
    "test_get_weights_col_packed_exl2": {
        "weight.q_weight": torch.tensor(
            [[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32
        ),
        "weight.q_scale": torch.tensor([8], dtype=torch.int32),
        "weight.q_invperm": torch.tensor([1, 0, 3, 2], dtype=torch.int32),
        "weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
        "weight.q_groups": torch.tensor([4], dtype=torch.int16),
    },
    "test_get_multi_weights_row_exl2": {
        "weight.q_weight": torch.tensor(
            [[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32
        ),
        "weight.q_scale": torch.tensor([8], dtype=torch.int32),
        "weight.q_invperm": torch.tensor([1, 0, 3, 2], dtype=torch.int32),
        "weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
        "weight.q_groups": torch.tensor([4], dtype=torch.int16),
    },
    "test_get_multi_weights_col_exl2": {
        "weight.q_weight": torch.tensor(
            [[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32
        ),
        "weight.q_scale": torch.tensor([8], dtype=torch.int32),
        "weight.q_invperm": torch.tensor([1, 0, 3, 2], dtype=torch.int32),
        "weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
        "weight.q_groups": torch.tensor([4], dtype=torch.int16),
    },
    "test_get_weights_col_exl2": {
        "weight.q_weight": torch.tensor(
            [[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32
        ),
        "weight.q_scale": torch.tensor([8], dtype=torch.int32),
        "weight.q_invperm": torch.tensor([1, 0, 3, 2], dtype=torch.int32),
        "weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
        "weight.q_groups": torch.tensor([4], dtype=torch.int16),
    },
    "test_get_multi_weights_row_marlin": {
        "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
        "weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
    },
    "test_get_multi_weights_col_marlin": {
        "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
        "weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
    },
    "test_get_weights_col_packed_marlin": {
        "weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
        "weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
    },
}

class MockSlice:
    def __init__(self, tensor):
        self.tensor = tensor

    def get_shape(self):
        return self.tensor.shape

    def __getitem__(self, idx):
        return self.tensor[idx]


def mock_get_slice(tensor_name, filename):
    tensor = dummy_file_system[filename][tensor_name]
    return MockSlice(tensor)


def mock_handle(filename, device, dtype):
    return SimpleNamespace(
        get_slice=lambda tensor_name: mock_get_slice(tensor_name, filename)
    )


class MockSafeOpen:
    def __init__(self, filename, framework, dummy_fs):
        self.filename = filename
        self.framework = framework
        self.dummy_fs = dummy_fs

    def keys(self):
        return list(self.dummy_fs[self.filename].keys())

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass


class MockWeights(Weights):
    def __init__(
        self,
        filenames: List[Union[Path, str]],
        device,
        dtype,
        process_group,
        dummy_fs,
        aliases: Optional[Dict[str, List[str]]] = None,
        prefix: Optional[str] = None,
    ):
        routing = {}
        self.dummy_fs = dummy_fs
        for filename in filenames:
            with MockSafeOpen(filename, framework="pytorch", dummy_fs=dummy_fs) as f:
                for k in f.keys():
                    if k in routing:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename
        if aliases is None:
            aliases = {}
        self.aliases = aliases
        self.routing = routing
        self.device = device
        self.dtype = dtype
        self.process_group = process_group
        self.prefix = prefix
        self._handles = {}

    def _get_handle(self, filename: Union[Path, str]):
        if filename in self._handles:
            return self._handles[filename]
        else:
            handle = mock_handle(filename, self.device, self.dtype)
            self._handles[filename] = handle
        return handle

    def get_shape(self, tensor_name: str):
        filename, _ = self.get_filename(tensor_name)
        handle = self._get_handle(filename)
        return handle.get_slice(tensor_name).get_shape()

    def get_tensor(self, tensor_name: str):
        filename, _ = self.get_filename(tensor_name)
        handle = self._get_handle(filename)
        return handle.get_slice(tensor_name).tensor


dummy_process_group = SimpleNamespace(rank=lambda: 0, size=lambda: 1)

def test_weights():
    weights = MockWeights(
        ["test_weights", "test_weights_2"],
        device="cpu",
        dtype=torch.float32,
        process_group=dummy_process_group,
        dummy_fs=dummy_file_system,
    )
    assert weights.get_shape("layer.0.weight") == (2, 2)
    assert weights.get_tensor("layer.1337.weight").shape == (2, 4)


def test_get_tensor():
    weights = MockWeights(
        ["test_weights", "test_weights_2"],
        device="cpu",
        dtype=torch.float32,
        process_group=dummy_process_group,
        dummy_fs=dummy_file_system,
    )
    assert torch.allclose(
        weights.get_tensor("layer.0.weight"),
        torch.tensor([[1, 2], [3, 4]], dtype=torch.float32),
    )
    assert torch.allclose(
        weights.get_tensor("layer.1337.weight"),
        torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.float32),
    )


def test_get_weights_col_packed():
    weights = MockWeights(
        ["test_get_weights_col_packed"],
        device="cpu",
        dtype=torch.float32,
        process_group=dummy_process_group,
        dummy_fs=dummy_file_system,
    )
    prefix = "weight"
    quantize = None
    block_sizes = 1
    w = weights.get_weights_col_packed(
        prefix=prefix, quantize=quantize, block_sizes=block_sizes
    )
    assert torch.allclose(
        w,
        torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.float32),
    )


def test_get_weights_col_packed_block_size():
    weights = MockWeights(
        ["test_get_weights_col_packed"],
        device="cpu",
        dtype=torch.float32,
        process_group=dummy_process_group,
        dummy_fs=dummy_file_system,
    )
    prefix = "weight"
    quantize = None
    block_sizes = 2
    w = weights.get_weights_col_packed(
        prefix=prefix, quantize=quantize, block_sizes=block_sizes
    )
    assert torch.allclose(
        w,
        torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.float32),
    )


def test_get_weights_col_packed_block_size_arr():
    weights = MockWeights(
        ["test_get_weights_col_packed"],
        device="cpu",
        dtype=torch.float32,
        process_group=dummy_process_group,
        dummy_fs=dummy_file_system,
    )
    prefix = "weight"
    quantize = None
    block_sizes = [1, 1]
    w = weights.get_weights_col_packed(
        prefix=prefix, quantize=quantize, block_sizes=block_sizes
    )
    assert torch.allclose(
        w,
        torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.float32),
    )


def test_get_multi_weights_col():
    weights = MockWeights(
        ["test_get_multi_weights_col"],
        device="cpu",
        dtype=torch.float32,
        process_group=dummy_process_group,
        dummy_fs=dummy_file_system,
    )
    prefixes = ["weight", "weight"]
    quantize = None
    w = weights.get_multi_weights_col(prefixes=prefixes, quantize=quantize, dim=0)
    assert torch.allclose(
        w,
        torch.tensor(
            [[1, 2], [3, 4], [5, 6], [7, 8], [1, 2], [3, 4], [5, 6], [7, 8]],
            dtype=torch.float32,
        ),
    )


def test_get_multi_weights_row():
    weights = MockWeights(
        ["test_get_multi_weights_row"],
        device="cpu",
        dtype=torch.float32,
        process_group=dummy_process_group,
        dummy_fs=dummy_file_system,
    )
    prefix = "weight"
    quantize = None
    w = weights.get_multi_weights_row(prefix=prefix, quantize=quantize)
    assert torch.allclose(
        w,
        torch.tensor(
            [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], dtype=torch.float32
        ),
    )

# test_get_weights_col


def test_get_weights_col_awq():
    weights = MockWeights(
        ["test_get_weights_col_gptq"],
        device="cpu",
        dtype=torch.float32,
        process_group=dummy_process_group,
        dummy_fs=dummy_file_system,
    )
    prefix = "weight"
    quantize = "awq"
    w = weights.get_weights_col(prefix=prefix, quantize=quantize)
    expected_weight = GPTQWeight(
        qweight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]),
        qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
        scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16),
        g_idx=None,
        bits=8.0,
        groupsize=2.0,
        use_exllama=False,
    )
    assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
    assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
    assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
    assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
    assert w.bits == expected_weight.bits, "bits mismatch"
    assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
    assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"


def test_get_weights_col_gtpq():
    weights = MockWeights(
        ["test_get_weights_col_gptq"],
        device="cpu",
        dtype=torch.float32,
        process_group=dummy_process_group,
        dummy_fs=dummy_file_system,
    )
    prefix = "weight"
    quantize = "gptq"
    w = weights.get_weights_col(prefix=prefix, quantize=quantize)
    expected_weight = GPTQWeight(
        qweight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]),
        qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
        scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16),
        g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),
        bits=8.0,
        groupsize=2.0,
        use_exllama=False,
    )
    assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
    assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
    assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
    assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
    assert w.bits == expected_weight.bits, "bits mismatch"
    assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
    assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"


def test_get_weights_col_exl2():
    weights = MockWeights(
        ["test_get_weights_col_exl2"],
        device="cpu",
        dtype=torch.float32,
        process_group=dummy_process_group,
        dummy_fs=dummy_file_system,
    )
    prefix = "weight"
    quantize = "exl2"
    w = weights.get_weights_col(prefix=prefix, quantize=quantize)
    scaled_scale_max = 0.3906 * 256
    expected_weight = Exl2Weight(
        q_weight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
        q_scale=torch.tensor([8], dtype=torch.int32),
        q_invperm=torch.tensor([1, 0, 3, 2], dtype=torch.int16),
        q_scale_max=torch.tensor([scaled_scale_max], dtype=torch.float16),
        q_groups=torch.tensor([4], dtype=torch.int16),
    )
    assert torch.allclose(w.q_weight, expected_weight.q_weight), "q_weight mismatch"
    assert torch.allclose(w.q_scale, expected_weight.q_scale), "q_scale mismatch"
    assert torch.allclose(w.q_invperm, expected_weight.q_invperm), "q_invperm mismatch"
    assert torch.allclose(w.q_scale_max, expected_weight.q_scale_max), "q_scale_max mismatch"
    assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"


def test_get_weights_col_marlin():
    weights = MockWeights(
        ["test_get_weights_col_marlin"],
        device="cpu",
        dtype=torch.float16,
        process_group=dummy_process_group,
        dummy_fs=dummy_file_system,
    )
    prefix = "weight"
    quantize = "marlin"
    w = weights.get_weights_col(prefix=prefix, quantize=quantize)
    expected_weight = MarlinWeight(
        B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
        s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
    )
    assert torch.allclose(w.B, expected_weight.B), "B mismatch"
    assert torch.allclose(w.s, expected_weight.s), "s mismatch"

# test_get_weights_col_packed


def test_get_weights_col_packed_awq():
    weights = MockWeights(
        ["test_get_weights_col_packed_gptq"],
        device="cpu",
        dtype=torch.float32,
        process_group=dummy_process_group,
        dummy_fs=dummy_file_system,
    )
    prefix = "weight"
    quantize = "awq"
    block_sizes = 1
    w = weights.get_weights_col_packed(
        prefix=prefix, quantize=quantize, block_sizes=block_sizes
    )
    expected_weight = GPTQWeight(
        qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
        qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
        scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16),
        g_idx=None,
        bits=8.0,
        groupsize=2.0,
        use_exllama=False,
    )
    assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
    assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
    assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
    assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
    assert w.bits == expected_weight.bits, "bits mismatch"
    assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
    assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"


@pytest.mark.skip(reason="Review expected functionality")
def test_get_weights_col_packed_exl2():
    weights = MockWeights(
        ["test_get_weights_col_packed_exl2"],
        device="cpu",
        dtype=torch.float32,
        process_group=dummy_process_group,
        dummy_fs=dummy_file_system,
    )
    prefix = "weight"
    quantize = "exl2"
    block_sizes = 1
    w = weights.get_weights_col_packed(
        prefix=prefix, quantize=quantize, block_sizes=block_sizes
    )
    scaled_scale_max = 0.3906 * 256
    expected_weight = Exl2Weight(
        q_weight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
        q_scale=torch.tensor([8], dtype=torch.int32),
        q_invperm=torch.tensor([1], dtype=torch.int16),
        q_scale_max=torch.tensor([scaled_scale_max], dtype=torch.float16),
        q_groups=torch.tensor([4], dtype=torch.int16),
    )
    assert torch.allclose(w.q_weight, expected_weight.q_weight), "q_weight mismatch"
    assert torch.allclose(w.q_scale, expected_weight.q_scale), "q_scale mismatch"
    assert torch.allclose(w.q_invperm, expected_weight.q_invperm), "q_invperm mismatch"
    assert torch.allclose(w.q_scale_max, expected_weight.q_scale_max), "q_scale_max mismatch"
    assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"


def test_get_weights_col_packed_gptq():
    weights = MockWeights(
        ["test_get_weights_col_packed_gptq"],
        device="cpu",
        dtype=torch.float32,
        process_group=dummy_process_group,
        dummy_fs=dummy_file_system,
    )
    prefixes = ["weight"]
    quantize = "gptq"
    w = weights.get_multi_weights_col(prefixes=prefixes, quantize=quantize, dim=0)
    expected_weight = GPTQWeight(
        qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
        qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
        scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16),
        g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),
        bits=8.0,
        groupsize=2.0,
        use_exllama=False,
    )
    assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
    assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
    assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
    assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
    assert w.bits == expected_weight.bits, "bits mismatch"
    assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
    assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"


def test_get_weights_col_packed_marlin():
    weights = MockWeights(
        ["test_get_weights_col_packed_marlin"],
        device="cpu",
        dtype=torch.float16,
        process_group=dummy_process_group,
        dummy_fs=dummy_file_system,
    )
    prefix = "weight"
    quantize = "marlin"
    w = weights.get_multi_weights_col(prefixes=[prefix], quantize=quantize, dim=0)
    expected_weight = MarlinWeight(
        B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
        s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
    )
    print(expected_weight)
    assert torch.allclose(w.B, expected_weight.B), "B mismatch"
    assert torch.allclose(w.s, expected_weight.s), "s mismatch"

# test_get_multi_weights_col


def test_get_multi_weights_col_awq():
    weights = MockWeights(
        ["test_get_multi_weights_col_gptq"],
        device="cpu",
        dtype=torch.float32,
        process_group=dummy_process_group,
        dummy_fs=dummy_file_system,
    )
    prefixes = ["weight"]
    quantize = "awq"
    w = weights.get_multi_weights_col(prefixes=prefixes, quantize=quantize, dim=0)
    expected_weight = GPTQWeight(
        qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
        qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
        scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16),
        g_idx=None,
        bits=8.0,
        groupsize=2.0,
        use_exllama=False,
    )
    assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
    assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
    assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
    assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
    assert w.bits == expected_weight.bits, "bits mismatch"
    assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
    assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"


def test_get_multi_weights_col_exl2():
    weights = MockWeights(
        ["test_get_multi_weights_col_exl2"],
        device="cpu",
        dtype=torch.float32,
        process_group=dummy_process_group,
        dummy_fs=dummy_file_system,
    )
    prefix = "weight"
    quantize = "exl2"
    try:
        w = weights.get_multi_weights_col(
            prefixes=[prefix],
            quantize=quantize,
            dim=0,
        )
    except ValueError as e:
        assert e.args[0] == "get_multi_weights_col is not supported for exl2"


def test_get_multi_weights_col_gptq():
    weights = MockWeights(
        ["test_get_multi_weights_col_gptq"],
        device="cpu",
        dtype=torch.float32,
        process_group=dummy_process_group,
        dummy_fs=dummy_file_system,
    )
    prefixes = ["weight"]
    quantize = "gptq"
    w = weights.get_multi_weights_col(prefixes=prefixes, quantize=quantize, dim=0)
    expected_weight = GPTQWeight(
        qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
        qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
        scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16),
        g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),
        bits=8.0,
        groupsize=2.0,
        use_exllama=False,
    )
    assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
    assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
    assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
    assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
    assert w.bits == expected_weight.bits, "bits mismatch"
    assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
    assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"


def test_get_multi_weights_col_marlin():
    weights = MockWeights(
        ["test_get_multi_weights_col_marlin"],
        device="cpu",
        dtype=torch.float16,
        process_group=dummy_process_group,
        dummy_fs=dummy_file_system,
    )
    prefix = "weight"
    quantize = "marlin"
    w = weights.get_multi_weights_col(prefixes=[prefix], quantize=quantize, dim=0)
    expected_weight = MarlinWeight(
        B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
        s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
    )
    assert torch.allclose(w.B, expected_weight.B), "B mismatch"
    assert torch.allclose(w.s, expected_weight.s), "s mismatch"

# test_get_multi_weights_row


def test_get_multi_weights_row_awq():
    weights = MockWeights(
        ["test_get_multi_weights_row_gptq"],
        device="cpu",
        dtype=torch.float32,
        process_group=dummy_process_group,
        dummy_fs=dummy_file_system,
    )
    prefix = "weight"
    quantize = "awq"
    w = weights.get_multi_weights_row(prefix=prefix, quantize=quantize)
    expected_weight = GPTQWeight(
        qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
        qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
        scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16),
        g_idx=None,
        bits=8.0,
        groupsize=2.0,
        use_exllama=False,
    )
    assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
    assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
    assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
    assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
    assert w.bits == expected_weight.bits, "bits mismatch"
    assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
    assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"


def test_get_multi_weights_row_exl2():
    weights = MockWeights(
        ["test_get_multi_weights_row_exl2"],
        device="cpu",
        dtype=torch.float32,
        process_group=dummy_process_group,
        dummy_fs=dummy_file_system,
    )
    prefix = "weight"
    quantize = "exl2"
    w = weights.get_multi_weights_row(prefix=prefix, quantize=quantize)
    print(w)
    scaled_scale_max = 0.3906 * 256
    expected_weight = Exl2Weight(
        q_weight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
        q_scale=torch.tensor([8], dtype=torch.int32),
        q_invperm=torch.tensor([1, 0, 3, 2], dtype=torch.int16),
        q_scale_max=torch.tensor([scaled_scale_max], dtype=torch.float16),
        q_groups=torch.tensor([4], dtype=torch.int16),
    )
    assert torch.allclose(w.q_weight, expected_weight.q_weight), "q_weight mismatch"
    assert torch.allclose(w.q_scale, expected_weight.q_scale), "q_scale mismatch"
    assert torch.allclose(w.q_invperm, expected_weight.q_invperm), "q_invperm mismatch"
    assert torch.allclose(w.q_scale_max, expected_weight.q_scale_max), "q_scale_max mismatch"
    assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"


def test_get_multi_weights_row_gptq():
    weights = MockWeights(
        ["test_get_multi_weights_row_gptq"],
        device="cpu",
        dtype=torch.float32,
        process_group=dummy_process_group,
        dummy_fs=dummy_file_system,
    )
    prefix = "weight"
    quantize = "gptq"
    w = weights.get_multi_weights_row(prefix=prefix, quantize=quantize)
    expected_weight = GPTQWeight(
        qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
        qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
        scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16),
        g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),
        bits=8.0,
        groupsize=2.0,
        use_exllama=False,
    )
    assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
    assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
    assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
    assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
    assert w.bits == expected_weight.bits, "bits mismatch"
    assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
    assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"


def test_get_multi_weights_row_marlin():
    weights = MockWeights(
        ["test_get_multi_weights_row_marlin"],
        device="cpu",
        dtype=torch.float16,
        process_group=dummy_process_group,
        dummy_fs=dummy_file_system,
    )
    prefix = "weight"
    quantize = "marlin"
    w = weights.get_multi_weights_row(prefix=prefix, quantize=quantize)
    expected_weight = MarlinWeight(
        B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
        s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
    )
    assert torch.allclose(w.B, expected_weight.B), "B mismatch"
    assert torch.allclose(w.s, expected_weight.s), "s mismatch"
server/text_generation_server/adapters/__init__.py (new file, mode 100644)

# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/__init__.py
# License: Apache License Version 2.0, January 2004
from text_generation_server.adapters.weights import (
    AdapterBatchData,
    AdapterBatchMetadata,
)

__all__ = [
    "AdapterBatchData",
    "AdapterBatchMetadata",
]
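
Because the package __init__ re-exports these two names, downstream code can import them from the package root instead of the weights submodule; a minimal illustrative sketch:

# Equivalent to importing from text_generation_server.adapters.weights directly.
from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata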
server/text_generation_server/adapters/config.py (new file, mode 100644)

# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/config.py
# License: Apache License Version 2.0, January 2004
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple

import torch

from text_generation_server.adapters.weights import AdapterWeights

if TYPE_CHECKING:
    from text_generation_server.models.model import Model


@dataclass
class ModuleMap:
    module_name: str
    module_weights: Dict[str, Tuple[torch.Tensor, str]]


@dataclass
class AdapterConfig(ABC):
    base_model_name_or_path: str

    @abstractmethod
    def map_weights_for_model(
        self,
        adapter_weights: Dict[int, AdapterWeights],
        weight_names: Tuple[str],
    ) -> Tuple[ModuleMap, Set[str]]:
        pass

    @abstractmethod
    def load_batched_adapter_weights(
        self,
        model: "Model",
        module_map: ModuleMap,
        layer_type: str,
        unused_weight_names: Set[str],
        dynamic: bool,
    ) -> Optional[AdapterWeights]:
        pass
server/text_generation_server/adapters/lora.py
0 → 100644
View file @
efd602c8
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/lora.py
# License: Apache License Version 2.0, January 2004
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
Union
import
torch
from
peft
import
LoraConfig
as
_LoraConfig
from
torch.distributed
import
ProcessGroup
from
text_generation_server.adapters.config
import
AdapterConfig
,
ModuleMap
from
text_generation_server.adapters.weights
import
(
AdapterBatchMetadata
,
AdapterWeights
,
BatchAdapterWeights
,
)
from
text_generation_server.utils.sgmv
import
(
BGMV_MAX_RANK
,
MAX_RANK_CUSTOM
,
get_tmp_tensors
,
orient_for_rank
,
pad_rank
,
use_cutlass_shrink
,
)
if
TYPE_CHECKING
:
from
text_generation_server.models.model
import
Model
def
get_start_stop_idxs_for_rank
(
offset
,
size
,
rank
,
world_size
):
block_size
=
size
//
world_size
start
=
offset
+
rank
*
block_size
stop
=
offset
+
(
rank
+
1
)
*
block_size
return
start
,
stop
def
shard_on_dim
(
t
:
torch
.
Tensor
,
dim
:
int
,
process_group
:
torch
.
distributed
.
ProcessGroup
):
world_size
=
process_group
.
size
()
rank
=
process_group
.
rank
()
size
=
t
.
shape
[
dim
]
start
,
stop
=
get_start_stop_idxs_for_rank
(
0
,
size
,
rank
,
world_size
)
if
dim
==
0
:
tensor
=
t
[
start
:
stop
]
elif
dim
==
1
:
tensor
=
t
[:,
start
:
stop
]
else
:
raise
NotImplementedError
(
"Let's make that generic when needed"
)
return
tensor
def
shard_lora_weights
(
weights_a
:
List
[
torch
.
Tensor
],
weights_b
:
List
[
torch
.
Tensor
],
split_dim
:
int
,
process_group
:
ProcessGroup
,
)
->
Tuple
[
List
[
torch
.
Tensor
],
List
[
torch
.
Tensor
]]:
# [hidden_size, r]
weights_a
=
[
shard_on_dim
(
w
,
dim
=
split_dim
,
process_group
=
process_group
)
for
w
in
weights_a
]
# [r, hidden_size]
weights_b
=
[
shard_on_dim
(
w
,
dim
=
1
,
process_group
=
process_group
)
for
w
in
weights_b
]
return
weights_a
,
weights_b
@
dataclass
class
LoraConfig
(
AdapterConfig
):
r
:
int
target_modules
:
Optional
[
Union
[
List
[
str
],
str
]]
fan_in_fan_out
:
bool
lora_alpha
:
int
use_rslora
:
bool
def
map_weights_for_model
(
self
,
adapter_weights
:
Dict
[
int
,
AdapterWeights
],
weight_names
:
Tuple
[
str
],
)
->
Tuple
[
ModuleMap
,
Set
[
str
]]:
adapter_weight_names
=
set
()
module_map
=
{}
for
weight_name
in
weight_names
:
lora_a_name
=
f
"base_model.model.
{
weight_name
}
.lora_A.weight"
lora_b_name
=
f
"base_model.model.
{
weight_name
}
.lora_B.weight"
            if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights:
                continue

            module_map[weight_name] = {
                "lora_A": (adapter_weights[lora_a_name], lora_a_name),
                "lora_B": (adapter_weights[lora_b_name], lora_b_name),
            }
            adapter_weight_names.add(lora_a_name)
            adapter_weight_names.add(lora_b_name)
        return module_map, adapter_weight_names

    def load_batched_adapter_weights(
        self,
        model: "Model",
        module_map: Dict[str, Dict],
        layer_type: str,
        unused_weight_names: Set[str],
        dynamic: bool,
    ) -> Optional[AdapterWeights]:
        return LoraWeights.load(
            self,
            model,
            module_map,
            layer_type,
            unused_weight_names,
        )

    @classmethod
    def load(cls, adapter_id: str, api_token: str) -> "LoraConfig":
        hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)
        return cls(
            base_model_name_or_path=hf_config.base_model_name_or_path,
            r=hf_config.r,
            target_modules=hf_config.target_modules,
            fan_in_fan_out=hf_config.fan_in_fan_out,
            lora_alpha=hf_config.lora_alpha,
            use_rslora=(
                hf_config.use_rslora if hasattr(hf_config, "use_rslora") else False
            ),
        )


class LoraWeights(AdapterWeights):
    """LoRA weights for a single adapter merged across all layers."""

    def __init__(
        self,
        weights_a: List[torch.Tensor],
        weights_b: List[torch.Tensor],
        adapter_config: LoraConfig,
    ):
        self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
        self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1

        self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r)
        self._is_transposed = False

        # [num_layers, hidden_size, r]
        weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
        self._weights_a = torch.stack(weights_a)

        # [num_layers, r, hidden_size]
        self._weights_b = torch.stack(weights_b)

        self.adapter_config = adapter_config

    @property
    def weights_a(self) -> torch.Tensor:
        if self._is_transposed:
            self._transpose_weights()
        return self._weights_a

    @property
    def weights_b(self) -> torch.Tensor:
        if self._is_transposed:
            self._transpose_weights()
        return self._weights_b

    @property
    def weights_a_t(self) -> torch.Tensor:
        if not self._is_transposed:
            self._transpose_weights()
        return self._weights_a

    @property
    def weights_b_t(self) -> torch.Tensor:
        if not self._is_transposed:
            self._transpose_weights()
        return self._weights_b

    def _transpose_weights(self):
        if self._use_cutlass_shrink:
            # If we're not using the cutlass shrink, then both SGMV and BGMV use the same orientation
            self._weights_a = self._weights_a.transpose(1, 2).contiguous()
        self._weights_b = self._weights_b.transpose(1, 2).contiguous()
        self._is_transposed = not self._is_transposed

    @classmethod
    def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
        return [BatchLoraWeights]

    @classmethod
    def load(
        cls,
        config: LoraConfig,
        model: "Model",
        module_map: Dict[str, Dict],
        layer_type: str,
        unused_weight_names: Set[str],
    ) -> Optional[AdapterWeights]:
        nlayers = model.get_num_layers_for_type(layer_type)
        lora_a_list = [None] * nlayers
        lora_b_list = [None] * nlayers

        for layer_id in range(nlayers):
            key = (layer_id, layer_type)
            weight_name, layer = model.target_to_layer[key]
            base_weight = layer.base_layer.linear.weight
            base_device = base_weight.device

            if weight_name not in module_map:
                # There is no LoRA weight for this layer type in the adapter
                return None

            lora_a, lora_a_name = module_map[weight_name]["lora_A"]
            lora_a = lora_a.to(base_device, model.dtype)

            lora_b, lora_b_name = module_map[weight_name]["lora_B"]
            lora_b = lora_b.to(base_device, model.dtype)

            scale = get_scaling_factor(
                config.lora_alpha,
                config.r,
                uses_rslora=config.use_rslora,
            )

            unused_weight_names.discard(lora_a_name)
            unused_weight_names.discard(lora_b_name)

            # Merge scaling factor into lora_b due to associativity of matrix multiplication:
            # (A * B) * C = A * (B * C)
            lora_a_list[layer_id] = lora_a.transpose(0, 1)
            lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale

        # pad lora ranks to be compatible with sgmv
        lora_a_list = [pad_rank(w, dim=1, world_size=model.world_size) for w in lora_a_list]
        lora_b_list = [pad_rank(w, dim=0, world_size=model.world_size) for w in lora_b_list]

        if lora_a_list:
            # update rank if it was padded
            padded_rank = lora_a_list[0].size(1)
            config.r = padded_rank

        return LoraWeights(
            *shard_lora_weights(
                weights_a=lora_a_list,
                weights_b=lora_b_list,
                split_dim=0 if model.is_row_parallel(layer_type) else 1,
                process_group=model.process_group,
            ),
            config,
        )


@dataclass
class RankSegments:
    rank: int

    lora_a_ptr: torch.Tensor
    lora_b_ptr: torch.Tensor

    # prefill (sgmv)
    tmp_shrink: torch.Tensor
    tmp_expand: torch.Tensor
    segment_starts: torch.Tensor
    segment_ends: torch.Tensor

    # decode (bgmv)
    indices: torch.Tensor


@dataclass
class BatchLoraWeights(BatchAdapterWeights):
    lora_a: Dict[int, torch.Tensor]
    lora_b: Dict[int, torch.Tensor]
    adapter_index_configs: Dict[int, LoraConfig]
    rank_data: Dict[int, RankSegments]
    use_sgmv: bool

    def has_adapter(self, adapter_index: int) -> bool:
        return adapter_index in self.adapter_index_configs

    def can_vectorize(self, pg: ProcessGroup) -> bool:
        return all(
            rank_data.rank // pg.size() <= MAX_RANK_CUSTOM
            for rank_data in self.rank_data.values()
        )

    @classmethod
    def key(cls) -> str:
        return "lora"

    @classmethod
    def load(
        self,
        adapter_weights: Dict[int, AdapterWeights],
        meta: AdapterBatchMetadata,
        prefill: bool,
        prefill_head_indices: Optional[torch.Tensor],
    ) -> Optional["BatchLoraWeights"]:
        adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
        adapter_weights = {
            k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)
        }
        if not adapter_weights:
            return None

        first_weights = next(iter(adapter_weights.values()))
        device = first_weights.weights_a.device
        segment_indices = meta.segment_indices

        lora_a = {
            idx: adapter_weights[idx].weights_a
            for idx in segment_indices
            if idx in adapter_weights
        }
        lora_b = {
            idx: adapter_weights[idx].weights_b
            for idx in segment_indices
            if idx in adapter_weights
        }

        max_rank = max(
            (
                adapter_weights[idx].lora_a_r
                for idx in segment_indices
                if idx in adapter_weights
            ),
            default=0,
        )

        if prefill or max_rank > BGMV_MAX_RANK:
            use_sgmv = True
            lora_a_ptr = torch.tensor(
                [
                    (adapter_weights[idx].weights_a.data_ptr() if idx in adapter_weights else 0)
                    for idx in segment_indices
                ],
                dtype=torch.int64,
                device=device,
            )
            lora_b_ptr = torch.tensor(
                [
                    (adapter_weights[idx].weights_b.data_ptr() if idx in adapter_weights else 0)
                    for idx in segment_indices
                ],
                dtype=torch.int64,
                device=device,
            )
        else:
            use_sgmv = False
            lora_a_ptr = torch.tensor(
                [
                    (adapter_weights[idx].weights_a_t.data_ptr() if idx in adapter_weights else 0)
                    for idx in segment_indices
                ],
                dtype=torch.int64,
                device=device,
            )
            lora_b_ptr = torch.tensor(
                [
                    (adapter_weights[idx].weights_b_t.data_ptr() if idx in adapter_weights else 0)
                    for idx in segment_indices
                ],
                dtype=torch.int64,
                device=device,
            )

        adapter_index_configs = {
            idx: adapter_weights[idx].adapter_config
            for idx in segment_indices
            if idx in adapter_weights
        }

        adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}

        rank_indices = defaultdict(list)
        for segment_idx, adapter_idx in enumerate(segment_indices):
            if adapter_idx not in adapter_weights:
                continue
            rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)

        if prefill_head_indices is not None:
            j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
            for head_index in prefill_head_indices:
                # j cannot go out of bounds as that would mean there are tokens without corresponding adapters
                if head_index < meta.adapter_segments[j]:
                    prefill_head_segment_ends[-1] += 1
                else:
                    prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
                    prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
                    j += 1

        rank_data = {}
        for rank, indices in rank_indices.items():
            tmp_shrink = None
            tmp_expand = None
            segment_starts = None
            segment_ends = None
            batch_indices = None

            if use_sgmv:
                lora_a_ptr_indices = lora_a_ptr[indices]
                tmp_shrink, tmp_expand = get_tmp_tensors(
                    lora_a_ptr_indices.size(0), rank, device
                )
                segment_starts = meta.adapter_segments[indices]
                segment_ends = meta.adapter_segments[[i + 1 for i in indices]]
                if prefill_head_indices is not None:
                    for i, segment_index in enumerate(indices):
                        segment_starts[i] = prefill_head_segment_starts[segment_index]
                        segment_ends[i] = prefill_head_segment_ends[segment_index]
            else:
                rank_indices = set(indices)
                batch_indices = [
                    adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()
                ]
                batch_indices = [
                    idx if idx in rank_indices else -1 for idx in batch_indices
                ]
                batch_indices = torch.tensor(batch_indices, dtype=torch.int64, device=device)

            rank_data[rank] = RankSegments(
                rank=rank,
                tmp_shrink=tmp_shrink,
                tmp_expand=tmp_expand,
                lora_a_ptr=lora_a_ptr[indices],
                lora_b_ptr=lora_b_ptr[indices],
                segment_starts=segment_starts,
                segment_ends=segment_ends,
                indices=batch_indices,
            )

        return BatchLoraWeights(
            lora_a=lora_a,
            lora_b=lora_b,
            adapter_index_configs=adapter_index_configs,
            rank_data=rank_data,
            use_sgmv=use_sgmv,
        )


def get_scaling_factor(
    lora_alpha: int,
    r: int,
    uses_rslora: bool = False,
) -> float:
    """Computes the scaling factor for the lora weights."""
    if uses_rslora:
        return lora_alpha / (r**0.5)
    return lora_alpha / r


def _convert_lora(v: AdapterWeights) -> AdapterWeights:
    if hasattr(v, "lora_weights"):
        return v.lora_weights
    return v
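The scaling logic above folds alpha/r (or alpha/sqrt(r) for rsLoRA) into lora_B at load time. A minimal standalone sketch, with made-up toy dimensions, of why that is safe:

# Sketch (not part of the commit): scaling B is equivalent to scaling the product,
# because (x @ A) @ (B * s) == ((x @ A) @ B) * s.
import torch

def get_scaling_factor(lora_alpha: int, r: int, uses_rslora: bool = False) -> float:
    # Same formula as above: classic LoRA uses alpha / r, rsLoRA uses alpha / sqrt(r).
    return lora_alpha / (r**0.5) if uses_rslora else lora_alpha / r

hidden, r = 16, 4                      # toy sizes, chosen arbitrarily for the example
x = torch.randn(2, hidden)
lora_a = torch.randn(hidden, r)
lora_b = torch.randn(r, hidden)
scale = get_scaling_factor(lora_alpha=8, r=r)

scaled_after = (x @ lora_a) @ lora_b * scale        # scale applied to the output
scaled_weight = (x @ lora_a) @ (lora_b * scale)     # scale pre-merged into lora_b, as LoraWeights.load does
assert torch.allclose(scaled_after, scaled_weight, atol=1e-5)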
server/text_generation_server/adapters/weights.py
0 → 100644
View file @
efd602c8
# Origin:  https://github.com/predibase/lorax
# Path:    lorax/server/lorax_server/adapters/weights.py
# License: Apache License Version 2.0, January 2004

from abc import ABC, abstractclassmethod
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Type

import torch


@dataclass
class AdapterBatchMetadata:
    # [batch_size]
    adapter_indices: torch.Tensor

    # [num_adapters]
    adapter_set: Set[int]

    # [num_segments + 1]
    adapter_segments: torch.Tensor

    # [num_segments]
    # maps from segment index to adapter index, i.e.:
    # segment_indices[s] == adapter_indices[i]
    segment_indices: List[int]


class AdapterWeights(ABC):
    @abstractclassmethod
    def get_batch_types(cls) -> List[Type["BatchAdapterWeights"]]:
        pass

    @property
    def speculative_tokens(self) -> int:
        return 0


class BatchAdapterWeights(ABC):
    @abstractclassmethod
    def has_adapter(self, adapter_index: int) -> bool:
        pass

    @abstractclassmethod
    def key(cls) -> str:
        pass

    @abstractclassmethod
    def load(
        cls,
        adapter_weights: Dict[int, AdapterWeights],
        meta: "AdapterBatchMetadata",
        prefill: bool,
        prefill_head_indices: torch.Tensor,
    ) -> Optional["BatchAdapterWeights"]:
        pass


class LayerAdapterWeights:
    """Adapter weights that apply to a particular layer."""

    def __init__(self):
        self.adapter_weights: Dict[int, AdapterWeights] = {}

    def add_adapter(self, adapter_idx: int, weights: AdapterWeights):
        self.adapter_weights[adapter_idx] = weights

    def remove_adapter(self, adapter_idx: int):
        if adapter_idx not in self.adapter_weights:
            return
        del self.adapter_weights[adapter_idx]

    @property
    def max_speculative_tokens(self) -> int:
        return max(
            adapter_weights.speculative_tokens
            for adapter_weights in self.adapter_weights.values()
        )

    def is_empty(self) -> bool:
        return len(self.adapter_weights) == 0

    def get_data(
        self,
        meta: AdapterBatchMetadata,
        prefill: bool,
        prefill_head_indices: Optional[torch.Tensor],
    ) -> Dict[str, BatchAdapterWeights]:
        # bucket adapters by batch class
        adapter_batch_types: Dict[
            Type[BatchAdapterWeights], Dict[int, AdapterWeights]
        ] = defaultdict(dict)
        for adapter_index, adapter_weights in self.adapter_weights.items():
            for batch_type in adapter_weights.get_batch_types():
                adapter_batch_types[batch_type][adapter_index] = adapter_weights

        batch_data = {}
        for batch_type, adapter_weights in adapter_batch_types.items():
            batched_weights = batch_type.load(
                adapter_weights, meta, prefill, prefill_head_indices
            )
            if batched_weights is not None:
                batch_data[batch_type.key()] = batched_weights
        return batch_data


@dataclass
class AdapterBatchData:
    meta: AdapterBatchMetadata

    # layer type -> adapter type -> batch weight data
    data: Dict[str, Dict[str, BatchAdapterWeights]]

    prefill: bool

    @staticmethod
    def from_meta(
        meta: AdapterBatchMetadata,
        weights: Dict[str, LayerAdapterWeights],
        prefill: bool,
        prefill_head_indices: Optional[torch.Tensor],
    ) -> "AdapterBatchData":
        data = {}
        for k, v in weights.items():
            if v.is_empty():
                continue
            data[k] = v.get_data(
                meta, prefill, prefill_head_indices if k == "lm_head" else None
            )
        return AdapterBatchData(meta=meta, data=data, prefill=prefill)

    def ranks(self) -> Set[int]:
        # TODO(travis): refactor to be less coupled to lora implementation
        ranks = set()
        for layer_data in self.data.values():
            lora_data = layer_data.get("lora")
            if lora_data is None:
                continue

            for rank_data in lora_data.rank_data.values():
                ranks.add(rank_data.rank)

        return ranks

    def layer_names(self) -> Set[str]:
        return set(self.data.keys())

    def adapter_keys(self) -> Set[str]:
        adapter_keys = set()
        for layer_data in self.data.values():
            adapter_keys.update(layer_data.keys())
        return adapter_keys

    @property
    def max_rank(self) -> int:
        ranks = self.ranks()
        return max(ranks) if len(ranks) > 0 else 0
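AdapterBatchMetadata only documents its fields here; the segments themselves are built elsewhere in the server. A rough standalone sketch, assuming a simple change-point scan, of how per-token adapter indices relate to the documented segment fields:

# Illustrative sketch (assumed, not from this file): deriving segment metadata from
# per-token adapter indices in the shape AdapterBatchMetadata describes.
import torch

adapter_indices = torch.tensor([0, 0, 1, 1, 1, 0])          # [batch_size]
change = torch.ones_like(adapter_indices, dtype=torch.bool)  # mark positions where the adapter id changes
change[1:] = adapter_indices[1:] != adapter_indices[:-1]
starts = torch.nonzero(change).flatten()
adapter_segments = torch.cat([starts, torch.tensor([len(adapter_indices)])])  # [num_segments + 1]
segment_indices = adapter_indices[starts].tolist()           # [num_segments]
adapter_set = set(adapter_indices.tolist())

print(adapter_segments.tolist())  # [0, 2, 5, 6]
print(segment_indices)            # [0, 1, 0]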
server/text_generation_server/cli.py
View file @
efd602c8
...
@@ -19,7 +19,9 @@ class Quantization(str, Enum):
     gptq = "gptq"
     awq = "awq"
     eetq = "eetq"
+    exl2 = "exl2"
     fp8 = "fp8"
+    marlin = "marlin"


 class Dtype(str, Enum):
...
@@ -40,6 +42,8 @@ def serve(
     logger_level: str = "INFO",
     json_output: bool = False,
     otlp_endpoint: Optional[str] = None,
+    otlp_service_name: str = "text-generation-inference.server",
+    max_input_tokens: Optional[int] = None,
 ):
     if sharded:
         assert (
...
@@ -73,7 +77,19 @@ def serve(
     # Setup OpenTelemetry distributed tracing
     if otlp_endpoint is not None:
-        setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)
+        setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)
+
+    lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)
+
+    # split on comma and strip whitespace
+    lora_adapter_ids = (
+        [x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []
+    )
+
+    if len(lora_adapter_ids) > 0:
+        logger.warning(
+            f"LoRA adapters are enabled. This is an experimental feature and may not work as expected."
+        )

     # Downgrade enum into str for easier management later on
     quantize = None if quantize is None else quantize.value
...
@@ -89,6 +105,7 @@ def serve(
     )
     server.serve(
         model_id,
+        lora_adapter_ids,
         revision,
         sharded,
         quantize,
...
@@ -96,6 +113,7 @@ def serve(
         dtype,
         trust_remote_code,
         uds_path,
+        max_input_tokens,
     )
...
@@ -108,6 +126,7 @@ def download_weights(
     logger_level: str = "INFO",
     json_output: bool = False,
     trust_remote_code: bool = False,
+    merge_lora: bool = False,
 ):
     # Remove default handler
     logger.remove()
...
@@ -138,47 +157,53 @@ def download_weights(
     ) is not None

     if not is_local_model:
-        try:
-            adapter_config_filename = hf_hub_download(
-                model_id, revision=revision, filename="adapter_config.json"
-            )
-            utils.download_and_unload_peft(
-                model_id, revision, trust_remote_code=trust_remote_code
-            )
-            is_local_model = True
-            utils.weight_files(model_id, revision, extension)
-            return
-        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
-            pass
+        # TODO: maybe reverse the default value of merge_lora?
+        # currently by default we don't merge the weights with the base model
+        if merge_lora:
+            try:
+                adapter_config_filename = hf_hub_download(
+                    model_id, revision=revision, filename="adapter_config.json"
+                )
+                utils.download_and_unload_peft(
+                    model_id, revision, trust_remote_code=trust_remote_code
+                )
+                is_local_model = True
+                utils.weight_files(model_id, revision, extension)
+                return
+            except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+                pass
+        else:
+            try:
+                utils.peft.download_peft(
+                    model_id, revision, trust_remote_code=trust_remote_code
+                )
+            except Exception:
+                pass

         try:
             import json

-            medusa_head = hf_hub_download(
-                model_id, revision=revision, filename="medusa_lm_head.safetensors"
-            )
-            medusa_config = hf_hub_download(
+            config = hf_hub_download(
                 model_id, revision=revision, filename="config.json"
             )
-            with open(medusa_config, "r") as f:
+            with open(config, "r") as f:
                 config = json.load(f)

-            model_id = config["base_model_name_or_path"]
-            revision = "main"
-            try:
-                utils.weight_files(model_id, revision, extension)
-                logger.info(
-                    f"Files for parent {model_id} are already present on the host. "
-                    "Skipping download."
-                )
-                return
-            # Local files not found
-            except (
-                utils.LocalEntryNotFoundError,
-                FileNotFoundError,
-                utils.EntryNotFoundError,
-            ):
-                pass
+            base_model_id = config.get("base_model_name_or_path", None)
+            if base_model_id and base_model_id != model_id:
+                try:
+                    logger.info(f"Downloading parent model {base_model_id}")
+                    download_weights(
+                        model_id=base_model_id,
+                        revision="main",
+                        extension=extension,
+                        auto_convert=auto_convert,
+                        logger_level=logger_level,
+                        json_output=json_output,
+                        trust_remote_code=trust_remote_code,
+                    )
+                except Exception:
+                    pass
         except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
             pass
...
@@ -195,31 +220,6 @@ def download_weights(
         if not extension == ".safetensors" or not auto_convert:
             raise e

-    elif (Path(model_id) / "medusa_lm_head.safetensors").exists():
-        # Try to load as a local Medusa model
-        try:
-            import json
-
-            medusa_head = Path(model_id) / "medusa_lm_head.safetensors"
-            medusa_config = Path(model_id) / "config.json"
-            with open(medusa_config, "r") as f:
-                config = json.load(f)
-
-            model_id = config["base_model_name_or_path"]
-            revision = "main"
-            try:
-                utils.weight_files(model_id, revision, extension)
-                logger.info(
-                    f"Files for parent {model_id} are already present on the host. "
-                    "Skipping download."
-                )
-                return
-            # Local files not found
-            except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
-                pass
-        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
-            pass
-
     elif (Path(model_id) / "adapter_config.json").exists():
         # Try to load as a local PEFT model
         try:
...
@@ -230,14 +230,43 @@ def download_weights(
             return
         except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
             pass

+    elif (Path(model_id) / "config.json").exists():
+        # Try to load as a local Medusa model
+        try:
+            import json
+
+            config = Path(model_id) / "config.json"
+            with open(config, "r") as f:
+                config = json.load(f)
+
+            base_model_id = config.get("base_model_name_or_path", None)
+            if base_model_id:
+                try:
+                    logger.info(f"Downloading parent model {base_model_id}")
+                    download_weights(
+                        model_id=base_model_id,
+                        revision="main",
+                        extension=extension,
+                        auto_convert=auto_convert,
+                        logger_level=logger_level,
+                        json_output=json_output,
+                        trust_remote_code=trust_remote_code,
+                    )
+                except Exception:
+                    pass
+        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+            pass
+
     # Try to see if there are local pytorch weights
     try:
         # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
-        local_pt_files = utils.weight_files(model_id, revision, ".bin")
+        try:
+            local_pt_files = utils.weight_files(model_id, revision, ".bin")
+        except Exception:
+            local_pt_files = utils.weight_files(model_id, revision, ".pt")

     # No local pytorch weights
-    except utils.LocalEntryNotFoundError:
+    except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
         if extension == ".safetensors":
             logger.warning(
                 f"No safetensors weights found for model {model_id} at revision {revision}. "
...
@@ -312,7 +341,7 @@ def quantize(
         logger_level=logger_level,
         json_output=json_output,
     )
-    from text_generation_server.utils.gptq.quantize import quantize
+    from text_generation_server.layers.gptq.quantize import quantize

     quantize(
         model_id=model_id,
...
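The serve() change above reads adapter ids from the LORA_ADAPTERS environment variable. A small sketch of that parsing, using made-up adapter names:

# Sketch of the LORA_ADAPTERS parsing added in serve(); the adapter ids are illustrative.
import os

os.environ["LORA_ADAPTERS"] = "org/adapter-one, org/adapter-two"
lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)
lora_adapter_ids = (
    [x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []
)
print(lora_adapter_ids)  # ['org/adapter-one', 'org/adapter-two']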
server/text_generation_server/layers/__init__.py
0 → 100644
View file @
efd602c8
from text_generation_server.layers.tensor_parallel import (
    TensorParallelColumnLinear,
    TensorParallelRowLinear,
    TensorParallelEmbedding,
)
from text_generation_server.layers.linear import (
    get_linear,
    FastLinear,
)
from text_generation_server.layers.speculative import SpeculativeHead

# Just to add the `load` methods.
from text_generation_server.layers.layernorm import load_layer_norm
from text_generation_server.layers.conv import load_conv2d

from text_generation_server.layers.lora import (
    LoraLinear,
    TensorParallelMultiAdapterLinear,
    TensorParallelAdapterRowLinear,
)
server/text_generation_server/layers/attention/__init__.py
0 → 100644
View file @
efd602c8
from text_generation_server.utils.import_utils import SYSTEM
import os

from .common import Seqlen

if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
    raise ImportError("`USE_FLASH_ATTENTION` is false.")

if SYSTEM == "cuda":
    from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif SYSTEM == "rocm":
    from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif SYSTEM == "ipex":
    from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
else:
    raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
server/text_generation_server/layers/attention/common.py
0 → 100644
View file @
efd602c8
from dataclasses import dataclass
from text_generation_server.models.globals import FLASH_DECODING
import torch
from typing import Optional

if FLASH_DECODING:

    @dataclass
    class Seqlen:
        input_lengths: torch.Tensor
        cu_seqlen_q: Optional[torch.Tensor]
        cu_seqlen_k: Optional[torch.Tensor]

        def __init__(self, input_lengths):
            self.input_lengths = input_lengths
            device = self.input_lengths.device
            shape = self.input_lengths.shape
            cu_seqlen_q = torch.arange(
                shape[0] + 1,
                device=device,
                dtype=torch.int32,
            )
            cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
            # cuda graphs don't like this and this is necessary to clamp within mistral
            # Although FA2 might not want the clamping
            # cu_seqlen_k[0] = 0
            torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:])

            self.cu_seqlen_q = cu_seqlen_q
            self.cu_seqlen_k = cu_seqlen_k

        def clamp(self, max):
            # Flash decoding doesn't need to clamp
            return self

else:

    @dataclass
    class Seqlen:
        input_lengths: torch.Tensor

        def clamp(self, max):
            return Seqlen(torch.clamp(self.input_lengths, max=max))
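Under FLASH_DECODING, Seqlen precomputes the cumulative sequence-length tensors expected by the varlen kernels. A standalone sketch of the same arithmetic, reimplemented here so it runs without the server's globals module:

# Sketch of what the FLASH_DECODING Seqlen constructor computes (standalone reimplementation).
import torch

input_lengths = torch.tensor([3, 5, 2], dtype=torch.int32)
cu_seqlen_q = torch.arange(input_lengths.shape[0] + 1, dtype=torch.int32)
cu_seqlen_k = torch.zeros(input_lengths.shape[-1] + 1, dtype=torch.int32)
torch.cumsum(input_lengths, -1, out=cu_seqlen_k[1:])

print(cu_seqlen_q.tolist())  # [0, 1, 2, 3]  -> one query token per sequence at decode time
print(cu_seqlen_k.tolist())  # [0, 3, 8, 10] -> prefix sums of the key lengths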
server/text_generation_server/layers/attention/cuda.py
0 → 100644
View file @
efd602c8
import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen

major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
_PARTITION_SIZE = 512

try:
    from vllm._C import cache_ops
    from vllm._C import ops
except Exception as e:
    raise ImportError(
        f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
    )


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slots: torch.Tensor,
):
    if FLASH_DECODING:
        shape = key_cache.shape
        key_cache.view(-1, shape[-2], shape[-1])[slots] = key
        value_cache.view(-1, shape[-2], shape[-1])[slots] = value
    else:
        cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)


def paged_attention(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    block_tables: torch.Tensor,
    seqlen: Seqlen,
    max_s: int,
):
    # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
    # Copyright 2023 The vLLM team. All rights
    # reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    #

    # value_cache => [num_blocks, num_heads, head_size, block_size]
    # block_size = value_cache.shape[3]
    block_size = BLOCK_SIZE
    num_seqs, num_heads, head_size = query.shape
    max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE

    # NOTE(woosuk): We use a simple heuristic to decide whether to use
    # PagedAttention V1 or V2. If the number of partitions is 1, we use
    # V1 to avoid the overhead of reduction. Also, if the number of
    # sequences or heads is large, we use V1 since there is enough work
    # to parallelize.
    if FLASH_DECODING:
        max_q = 1
        max_k = max_s
        import flash_attn_2_cuda

        # TODO fixme when flash contains the fix.
        # Number of splits is not correctly handled
        # by the current path
        # https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577
        # This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied.
        out2 = flash_attn_2_cuda.varlen_fwd(
            query,
            key_cache,
            value_cache,
            None,
            seqlen.cu_seqlen_q,
            seqlen.cu_seqlen_k,
            None,
            block_tables,
            None,
            max_q,
            max_k,
            0.0,  # dropout
            softmax_scale,
            False,  # zero_tensors
            True,  # causal
            -1,  # Window_left
            -1,  # Window right
            False,  # return softmax
            None,  # generator
        )
        return out2[0]
    else:
        input_lengths = seqlen.input_lengths
        from vllm._C import ops

        use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
        if use_v1:
            ops.paged_attention_v1(
                out,
                query,
                key_cache,
                value_cache,
                kv_head_mapping,
                softmax_scale,
                block_tables,
                input_lengths,
                block_size,
                max_s,
                None,
                "auto",
                1.0,
            )
        else:
            # Run PagedAttention V2.
            assert _PARTITION_SIZE % block_size == 0
            tmp_output = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions, head_size),
                dtype=out.dtype,
                device=out.device,
            )
            exp_sums = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions),
                dtype=torch.float32,
                device=out.device,
            )
            max_logits = torch.empty_like(exp_sums)

            ops.paged_attention_v2(
                out,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                kv_head_mapping,
                softmax_scale,
                block_tables,
                input_lengths,
                block_size,
                max_s,
                None,
                "auto",
                1.0,
            )
    return out


try:
    import flash_attn_2_cuda

    V2 = True
except ImportError:
    try:
        import flash_attn_cuda

        V2 = False
    except ImportError as e:
        if major >= 8:
            architecture_suffix = f"-{SYSTEM}"
            raise ImportError(
                "Flash Attention V2 is not installed.\n"
                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
                f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
            )
        elif is_sm75:
            raise ImportError(
                "Flash Attention is not installed.\n"
                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
                "or install flash attention with `cd server && make install install-flash-attention`"
            ) from e
        else:
            raise ImportError(f"GPU with CUDA capability {major} {minor} is not supported") from e


SUPPORTS_WINDOWING = V2
if V2:

    def attention(
        q,
        k,
        v,
        out,
        cu_seqlens,
        max_s,
        softmax_scale,
        window_size_left=-1,
        causal=True,
    ):
        if window_size_left <= 0 and window_size_left != -1:
            raise ValueError("`window_size_left` must be > 0 or -1")
        return flash_attn_2_cuda.varlen_fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            None,
            None,
            None,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            causal,
            window_size_left,
            0,
            False,
            None,
        )

else:

    def attention(
        q,
        k,
        v,
        out,
        cu_seqlens,
        max_s,
        softmax_scale,
        window_size_left=-1,
    ):
        if window_size_left != -1:
            raise NotImplementedError("window_size_left is only available with flash attn v2")

        # Flash attention v1 requires q, k and v to have the same number of heads
        if k.shape[1] != q.shape[1]:
            # MQA expand
            if k.shape[1] == 1:
                k = k.expand(-1, q.shape[1], -1)
            # Grouped attention reshape
            else:
                original_shape = k.shape
                k = (
                    k.unsqueeze(2)
                    .expand(-1, -1, q.shape[1] // k.shape[1], -1)
                    .reshape(original_shape[0], -1, original_shape[2])
                )
        if v.shape[1] != q.shape[1]:
            # MQA expand
            if v.shape[1] == 1:
                v = v.expand(-1, q.shape[1], -1)
            # Grouped attention reshape
            else:
                original_shape = v.shape
                v = (
                    v.unsqueeze(2)
                    .expand(-1, -1, q.shape[1] // v.shape[1], -1)
                    .reshape(original_shape[0], -1, original_shape[2])
                )

        return flash_attn_cuda.fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            True,
            False,
            0,
            None,
        )
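Flash attention v1 requires q, k and v to share a head count, so the fallback above expands kv heads for MQA/GQA before calling the kernel. A standalone sketch of that expansion on dummy tensors with illustrative shapes:

# Sketch of the MQA/GQA head expansion used in the V1 fallback above, run on dummy
# tensors shaped [total_tokens, num_heads, head_dim] (sizes are illustrative only).
import torch

q = torch.randn(10, 8, 64)   # 8 query heads
k = torch.randn(10, 2, 64)   # 2 kv heads (grouped attention)

if k.shape[1] != q.shape[1]:
    if k.shape[1] == 1:
        k = k.expand(-1, q.shape[1], -1)                   # MQA: broadcast the single kv head
    else:
        original_shape = k.shape
        k = (
            k.unsqueeze(2)
            .expand(-1, -1, q.shape[1] // k.shape[1], -1)  # repeat each kv head across its query group
            .reshape(original_shape[0], -1, original_shape[2])
        )

print(k.shape)  # torch.Size([10, 8, 64])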
server/text_generation_server/layers/attention/flash_attn_triton.py
0 → 100644
View file @
efd602c8
#!/usr/bin/env python
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
(https://tridao.me/publications/flash2/flash2.pdf)
Credits: OpenAI kernel team, AMD ML Frameworks Triton team
Features supported:
1) Fwd with causal masking
2) Any sequence lengths without padding (currently fwd kernel only)
3) Support for different sequence lengths for q and k
4) Nested tensor API currently does not support dropout or bias.
Not currently supported:
1) Non power of two head dims
"""
import torch
import triton
import triton.language as tl

torch_dtype: tl.constexpr = torch.float16


@triton.jit
def cdiv_fn(x, y):
    return (x + y - 1) // y


@triton.jit
def max_fn(x, y):
    return tl.math.max(x, y)


@triton.jit
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
    ms = tl.arange(0, m)
    ns = tl.arange(0, n)
    return philox_offset + ms[:, None] * stride + ns[None, :]


@triton.jit
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
    rng_offsets = dropout_offsets(
        philox_seed, philox_offset, dropout_p, m, n, stride
    ).to(tl.uint32)
    # TODO: use tl.randint for better performance
    return tl.rand(philox_seed, rng_offsets)


@triton.jit
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
    rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride)
    rng_keep = rng_output > dropout_p
    return rng_keep


@triton.jit
def load_fn(block_ptr, first, second, pad):
    if first and second:
        tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)
    elif first:
        tensor = tl.load(block_ptr, boundary_check=(0,), padding_option=pad)
    elif second:
        tensor = tl.load(block_ptr, boundary_check=(1,), padding_option=pad)
    else:
        tensor = tl.load(block_ptr)
    return tensor
@
triton
.
jit
def
_attn_fwd_inner
(
acc
,
l_i
,
m_i
,
q
,
K_block_ptr
,
V_block_ptr
,
start_m
,
actual_seqlen_k
,
dropout_p
,
philox_seed
,
batch_philox_offset
,
encoded_softmax_block_ptr
,
block_min
,
block_max
,
offs_n_causal
,
masked_blocks
,
n_extra_tokens
,
bias_ptr
,
IS_CAUSAL
:
tl
.
constexpr
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
OFFS_M
:
tl
.
constexpr
,
OFFS_N
:
tl
.
constexpr
,
PRE_LOAD_V
:
tl
.
constexpr
,
MASK_STEPS
:
tl
.
constexpr
,
ENABLE_DROPOUT
:
tl
.
constexpr
,
RETURN_ENCODED_SOFTMAX
:
tl
.
constexpr
,
PADDED_HEAD
:
tl
.
constexpr
,
):
# loop over k, v, and update accumulator
for
start_n
in
range
(
block_min
,
block_max
,
BLOCK_N
):
# For padded blocks, we will overrun the tensor size if
# we load all BLOCK_N. For others, the blocks are all within range.
k
=
load_fn
(
K_block_ptr
,
PADDED_HEAD
,
MASK_STEPS
and
(
n_extra_tokens
!=
0
),
"zero"
,
)
if
PRE_LOAD_V
:
v
=
load_fn
(
V_block_ptr
,
MASK_STEPS
and
(
n_extra_tokens
!=
0
),
PADDED_HEAD
,
"zero"
,
)
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
# We start from end of seqlen_k so only the first iteration would need
# to be checked for padding if it is not a multiple of block_n
# TODO: This can be optimized to only be true for the padded block.
if
MASK_STEPS
:
# noqa: SIM102
# If this is the last block / iteration, we want to
# mask if the sequence length is not a multiple of block size
# a solution is to always do BLOCK_M // BLOCK_N + 1 steps
# if not is_modulo_mn. last step might get wasted but that is okay.
# check if this masking works for that case.
if
(
start_n
+
BLOCK_N
==
block_max
)
and
(
n_extra_tokens
!=
0
):
boundary_m
=
tl
.
full
([
BLOCK_M
],
actual_seqlen_k
,
dtype
=
tl
.
int32
)
size_n
=
start_n
+
OFFS_N
[
None
,
:]
mask
=
size_n
<
boundary_m
[:,
None
]
qk
=
tl
.
where
(
mask
,
qk
,
float
(
"-inf"
))
if
IS_CAUSAL
:
causal_boundary
=
start_n
+
offs_n_causal
causal_mask
=
OFFS_M
[:,
None
]
>=
causal_boundary
[
None
,
:]
qk
=
tl
.
where
(
causal_mask
,
qk
,
float
(
"-inf"
))
# -- compute qk ----
qk
+=
tl
.
dot
(
q
,
k
)
if
bias_ptr
is
not
None
:
bias
=
load_fn
(
bias_ptr
,
False
,
MASK_STEPS
and
(
n_extra_tokens
!=
0
),
"zero"
)
# While bias is added after multiplying qk with sm_scale, our
# optimization to use 2^x instead of e^x results in an additional
# scale factor of log2(e) which we must also multiply the bias with.
qk
+=
bias
*
1.44269504089
m_ij
=
tl
.
maximum
(
m_i
,
tl
.
max
(
qk
,
1
))
qk
=
qk
-
m_ij
[:,
None
]
p
=
tl
.
math
.
exp2
(
qk
)
# CAVEAT: Must update l_ij before applying dropout
l_ij
=
tl
.
sum
(
p
,
1
)
if
ENABLE_DROPOUT
:
philox_offset
=
(
batch_philox_offset
+
start_m
*
BLOCK_M
*
actual_seqlen_k
+
start_n
-
BLOCK_N
)
keep
=
dropout_mask
(
philox_seed
,
philox_offset
,
dropout_p
,
BLOCK_M
,
BLOCK_N
,
actual_seqlen_k
,
)
if
RETURN_ENCODED_SOFTMAX
:
tl
.
store
(
encoded_softmax_block_ptr
,
tl
.
where
(
keep
,
p
,
-
p
).
to
(
encoded_softmax_block_ptr
.
type
.
element_ty
),
)
p
=
tl
.
where
(
keep
,
p
,
0.0
)
elif
RETURN_ENCODED_SOFTMAX
:
tl
.
store
(
encoded_softmax_block_ptr
,
p
.
to
(
encoded_softmax_block_ptr
.
type
.
element_ty
),
)
# -- update output accumulator --
alpha
=
tl
.
math
.
exp2
(
m_i
-
m_ij
)
acc
=
acc
*
alpha
[:,
None
]
if
not
PRE_LOAD_V
:
v
=
load_fn
(
V_block_ptr
,
MASK_STEPS
and
(
n_extra_tokens
!=
0
),
PADDED_HEAD
,
"zero"
,
)
# -- update m_i and l_i
l_i
=
l_i
*
alpha
+
l_ij
# update m_i and l_i
m_i
=
m_ij
acc
+=
tl
.
dot
(
p
.
to
(
V_block_ptr
.
type
.
element_ty
),
v
)
V_block_ptr
=
tl
.
advance
(
V_block_ptr
,
(
BLOCK_N
,
0
))
K_block_ptr
=
tl
.
advance
(
K_block_ptr
,
(
0
,
BLOCK_N
))
if
bias_ptr
is
not
None
:
bias_ptr
=
tl
.
advance
(
bias_ptr
,
(
0
,
BLOCK_N
))
if
RETURN_ENCODED_SOFTMAX
:
encoded_softmax_block_ptr
=
tl
.
advance
(
encoded_softmax_block_ptr
,
(
0
,
BLOCK_N
)
)
return
acc
,
l_i
,
m_i
@
triton
.
autotune
(
configs
=
[
triton
.
Config
(
{
"BLOCK_M"
:
256
,
"BLOCK_N"
:
64
,
"waves_per_eu"
:
2
,
"PRE_LOAD_V"
:
False
,
},
num_stages
=
1
,
num_warps
=
8
,
),
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
128
,
"waves_per_eu"
:
2
,
"PRE_LOAD_V"
:
False
,
},
num_stages
=
1
,
num_warps
=
4
,
),
triton
.
Config
(
{
"BLOCK_M"
:
256
,
"BLOCK_N"
:
128
,
"waves_per_eu"
:
2
,
"PRE_LOAD_V"
:
False
,
},
num_stages
=
1
,
num_warps
=
8
,
),
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
64
,
"waves_per_eu"
:
3
,
"PRE_LOAD_V"
:
True
,
},
num_stages
=
1
,
num_warps
=
4
,
),
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
64
,
"waves_per_eu"
:
3
,
"PRE_LOAD_V"
:
False
,
},
num_stages
=
1
,
num_warps
=
4
,
),
triton
.
Config
(
{
"BLOCK_M"
:
64
,
"BLOCK_N"
:
64
,
"waves_per_eu"
:
4
,
"PRE_LOAD_V"
:
False
,
},
num_stages
=
1
,
num_warps
=
8
,
),
triton
.
Config
(
{
"BLOCK_M"
:
32
,
"BLOCK_N"
:
32
,
"waves_per_eu"
:
4
,
"PRE_LOAD_V"
:
False
,
},
num_stages
=
1
,
num_warps
=
8
,
),
# TODO: This config fails with head_size not pow2 with data mismatches.
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
# 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
triton
.
Config
(
{
"BLOCK_M"
:
16
,
"BLOCK_N"
:
16
,
"waves_per_eu"
:
1
,
"PRE_LOAD_V"
:
False
,
},
num_stages
=
1
,
num_warps
=
4
,
),
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
64
,
"waves_per_eu"
:
1
,
"PRE_LOAD_V"
:
False
,
},
num_stages
=
1
,
num_warps
=
4
,
),
],
key
=
[
"IS_CAUSAL"
,
"dropout_p"
,
"BLOCK_DMODEL"
],
)
@
triton
.
jit
def
attn_fwd
(
Q
,
K
,
V
,
bias
,
sm_scale
,
L
,
Out
,
stride_qz
,
stride_qh
,
stride_qm
,
stride_qk
,
stride_kz
,
stride_kh
,
stride_kn
,
stride_kk
,
stride_vz
,
stride_vh
,
stride_vk
,
stride_vn
,
stride_oz
,
stride_oh
,
stride_om
,
stride_on
,
stride_bz
,
stride_bh
,
stride_bm
,
stride_bn
,
cu_seqlens_q
,
cu_seqlens_k
,
dropout_p
,
philox_seed
,
philox_offset_base
,
encoded_softmax
,
HQ
:
tl
.
constexpr
,
HK
:
tl
.
constexpr
,
ACTUAL_BLOCK_DMODEL
:
tl
.
constexpr
,
MAX_SEQLENS_Q
:
tl
.
constexpr
,
MAX_SEQLENS_K
:
tl
.
constexpr
,
VARLEN
:
tl
.
constexpr
,
IS_CAUSAL
:
tl
.
constexpr
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
PRE_LOAD_V
:
tl
.
constexpr
,
BIAS_TYPE
:
tl
.
constexpr
,
ENABLE_DROPOUT
:
tl
.
constexpr
,
RETURN_ENCODED_SOFTMAX
:
tl
.
constexpr
,
):
start_m
=
tl
.
program_id
(
0
)
off_h_q
=
tl
.
program_id
(
1
)
off_z
=
tl
.
program_id
(
2
)
offs_m
=
start_m
*
BLOCK_M
+
tl
.
arange
(
0
,
BLOCK_M
)
offs_n
=
tl
.
arange
(
0
,
BLOCK_N
)
if
VARLEN
:
cu_seqlens_q_start
=
tl
.
load
(
cu_seqlens_q
+
off_z
)
cu_seqlens_q_end
=
tl
.
load
(
cu_seqlens_q
+
off_z
+
1
)
seqlen_q
=
cu_seqlens_q_end
-
cu_seqlens_q_start
# We have a one-size-fits-all grid in id(0). Some seqlens might be too
# small for all start_m so for those we return early.
if
start_m
*
BLOCK_M
>
seqlen_q
:
return
cu_seqlens_k_start
=
tl
.
load
(
cu_seqlens_k
+
off_z
)
cu_seqlens_k_end
=
tl
.
load
(
cu_seqlens_k
+
off_z
+
1
)
seqlen_k
=
cu_seqlens_k_end
-
cu_seqlens_k_start
else
:
cu_seqlens_q_start
=
0
cu_seqlens_k_start
=
0
seqlen_q
=
MAX_SEQLENS_Q
seqlen_k
=
MAX_SEQLENS_K
# Now we compute whether we need to exit early due to causal masking.
# This is because for seqlen_q > seqlen_k, M rows of the attn scores
# are completely masked, resulting in 0s written to the output, and
# inf written to LSE. We don't need to do any GEMMs in this case.
# This block of code determines what N is, and if this WG is operating
# on those M rows.
n_blocks
=
cdiv_fn
(
seqlen_k
,
BLOCK_N
)
if
IS_CAUSAL
:
# If seqlen_q == seqlen_k, the attn scores are a square matrix.
# If seqlen_q != seqlen_k, attn scores are rectangular which means
# the causal mask boundary is bottom right aligned, and ends at either
# the top edge (seqlen_q < seqlen_k) or left edge.
# This captures the decrease in n_blocks if we have a rectangular attn
# matrix
n_blocks_seqlen
=
cdiv_fn
(
(
start_m
+
1
)
*
BLOCK_M
+
seqlen_k
-
seqlen_q
,
BLOCK_N
)
# This is what adjusts the block_max for the current WG, only
# if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
n_blocks
=
min
(
n_blocks
,
n_blocks_seqlen
)
# If we have no blocks after adjusting for seqlen deltas, this WG is
# part of the blocks that are all 0. We exit early.
if
n_blocks
<=
0
:
o_offset
=
(
off_z
*
stride_oz
+
cu_seqlens_q_start
*
stride_om
+
off_h_q
*
stride_oh
)
O_block_ptr
=
tl
.
make_block_ptr
(
base
=
Out
+
o_offset
,
shape
=
(
seqlen_q
,
BLOCK_DMODEL
),
strides
=
(
stride_om
,
stride_on
),
offsets
=
(
start_m
*
BLOCK_M
,
0
),
block_shape
=
(
BLOCK_M
,
BLOCK_DMODEL
),
order
=
(
1
,
0
),
)
acc
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_DMODEL
],
dtype
=
Out
.
type
.
element_ty
)
# We still need to write 0s to the result
# tl.store(O_block_ptr,
# acc.to(Out.type.element_ty), boundary_check=(0,1))
# l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
# + offs_m
# We store inf to LSE, not -inf because in the bwd pass,
# we subtract this
# from qk which makes it -inf, such that exp(qk - inf) = 0
# for these masked blocks.
# l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
# tl.store(l_ptrs, l)
# TODO: Should dropout and return encoded softmax be handled here?
return
# If MQA / GQA, set the K and V head offsets appropriately.
GROUP_SIZE
:
tl
.
constexpr
=
HQ
//
HK
if
GROUP_SIZE
!=
1
:
off_h_k
=
off_h_q
//
GROUP_SIZE
else
:
off_h_k
=
off_h_q
n_extra_tokens
=
0
if
seqlen_k
<
BLOCK_N
:
n_extra_tokens
=
BLOCK_N
-
seqlen_k
elif
seqlen_k
%
BLOCK_N
:
n_extra_tokens
=
seqlen_k
%
BLOCK_N
PADDED_HEAD
:
tl
.
constexpr
=
ACTUAL_BLOCK_DMODEL
!=
BLOCK_DMODEL
# Compute pointers for all the tensors used in this kernel.
q_offset
=
off_z
*
stride_qz
+
off_h_q
*
stride_qh
+
cu_seqlens_q_start
*
stride_qm
Q_block_ptr
=
tl
.
make_block_ptr
(
base
=
Q
+
q_offset
,
shape
=
(
seqlen_q
,
ACTUAL_BLOCK_DMODEL
),
strides
=
(
stride_qm
,
stride_qk
),
offsets
=
(
start_m
*
BLOCK_M
,
0
),
block_shape
=
(
BLOCK_M
,
BLOCK_DMODEL
),
order
=
(
1
,
0
),
)
k_offset
=
off_z
*
stride_kz
+
off_h_k
*
stride_kh
+
cu_seqlens_k_start
*
stride_kn
K_block_ptr
=
tl
.
make_block_ptr
(
base
=
K
+
k_offset
,
shape
=
(
ACTUAL_BLOCK_DMODEL
,
seqlen_k
),
strides
=
(
stride_kk
,
stride_kn
),
offsets
=
(
0
,
0
),
block_shape
=
(
BLOCK_DMODEL
,
BLOCK_N
),
order
=
(
0
,
1
),
)
v_offset
=
off_z
*
stride_vz
+
off_h_k
*
stride_vh
+
cu_seqlens_k_start
*
stride_vk
V_block_ptr
=
tl
.
make_block_ptr
(
base
=
V
+
v_offset
,
shape
=
(
seqlen_k
,
ACTUAL_BLOCK_DMODEL
),
strides
=
(
stride_vk
,
stride_vn
),
offsets
=
(
0
,
0
),
block_shape
=
(
BLOCK_N
,
BLOCK_DMODEL
),
order
=
(
1
,
0
),
)
if
BIAS_TYPE
!=
0
:
bias_ptr
=
tl
.
make_block_ptr
(
base
=
bias
+
off_h_q
*
stride_bh
,
shape
=
(
seqlen_q
,
seqlen_k
),
strides
=
(
stride_bm
,
stride_bn
),
offsets
=
(
start_m
*
BLOCK_M
,
0
),
block_shape
=
(
BLOCK_M
,
BLOCK_N
),
order
=
(
1
,
0
),
)
else
:
bias_ptr
=
None
if
ENABLE_DROPOUT
:
batch_philox_offset
=
(
philox_offset_base
+
(
off_z
*
HQ
+
off_h_q
)
*
seqlen_q
*
seqlen_k
)
else
:
batch_philox_offset
=
0
# We can ask to return the dropout mask without actually doing any dropout.
# In this case, we return an invalid pointer so indicate the mask is not i
# valid.
# TODO: Fix encoded softmax. It currently uses just h_q in the base offset.
if
RETURN_ENCODED_SOFTMAX
:
encoded_softmax_block_ptr
=
tl
.
make_block_ptr
(
base
=
encoded_softmax
+
off_h_q
*
seqlen_q
*
seqlen_k
,
shape
=
(
seqlen_q
,
seqlen_k
),
strides
=
(
seqlen_k
,
1
),
offsets
=
(
start_m
*
BLOCK_M
,
0
),
block_shape
=
(
BLOCK_M
,
BLOCK_N
),
order
=
(
1
,
0
),
)
else
:
encoded_softmax_block_ptr
=
0
# initialize pointer to m and l
m_i
=
tl
.
full
([
BLOCK_M
],
float
(
"-inf"
),
dtype
=
tl
.
float32
)
l_i
=
tl
.
full
([
BLOCK_M
],
1.0
,
dtype
=
tl
.
float32
)
acc
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_DMODEL
],
dtype
=
tl
.
float32
)
# scale sm_scale by log_2(e) and use 2^x in the loop as we do not
# have native e^x support in HW.
qk_scale
=
sm_scale
*
1.44269504089
# Q is loaded once at the beginning and shared by all N blocks.
q
=
load_fn
(
Q_block_ptr
,
True
,
PADDED_HEAD
,
"zero"
)
q
=
(
q
*
qk_scale
).
to
(
Q_block_ptr
.
type
.
element_ty
)
# Here we compute how many full and masked blocks we have.
padded_block_k
=
n_extra_tokens
!=
0
is_modulo_mn
=
not
padded_block_k
and
(
seqlen_q
%
BLOCK_M
==
0
)
if
IS_CAUSAL
:
# There are always at least BLOCK_M // BLOCK_N masked blocks.
# Additionally there might be one more due to dissimilar seqlens.
masked_blocks
=
BLOCK_M
//
BLOCK_N
+
(
not
is_modulo_mn
)
else
:
# Padding on Q does not need to be masked in the FA loop.
masked_blocks
=
padded_block_k
# if IS_CAUSAL, not is_modulo_mn does not always result in an additional
# block. In this case we might exceed n_blocks so pick the min.
masked_blocks
=
min
(
masked_blocks
,
n_blocks
)
n_full_blocks
=
n_blocks
-
masked_blocks
block_min
=
0
block_max
=
n_blocks
*
BLOCK_N
# Compute for full blocks. Here we set causal to false regardless of its
# value because there is no masking. Similarly we do not need padding.
if
n_full_blocks
>
0
:
block_max
=
(
n_blocks
-
masked_blocks
)
*
BLOCK_N
acc
,
l_i
,
m_i
=
_attn_fwd_inner
(
acc
,
l_i
,
m_i
,
q
,
K_block_ptr
,
V_block_ptr
,
start_m
,
seqlen_k
,
dropout_p
,
philox_seed
,
batch_philox_offset
,
encoded_softmax_block_ptr
,
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
block_min
,
block_max
,
0
,
0
,
0
,
bias_ptr
,
# IS_CAUSAL, ....
False
,
BLOCK_M
,
BLOCK_DMODEL
,
BLOCK_N
,
offs_m
,
offs_n
,
# _, MASK_STEPS, ...
PRE_LOAD_V
,
False
,
ENABLE_DROPOUT
,
RETURN_ENCODED_SOFTMAX
,
PADDED_HEAD
,
)
block_min
=
block_max
block_max
=
n_blocks
*
BLOCK_N
tl
.
debug_barrier
()
# Remaining blocks, if any, are full / not masked.
if
masked_blocks
>
0
:
offs_n_causal
=
offs_n
+
(
seqlen_q
-
seqlen_k
)
if
IS_CAUSAL
else
0
K_block_ptr
=
tl
.
advance
(
K_block_ptr
,
(
0
,
n_full_blocks
*
BLOCK_N
))
V_block_ptr
=
tl
.
advance
(
V_block_ptr
,
(
n_full_blocks
*
BLOCK_N
,
0
))
if
bias_ptr
is
not
None
:
bias_ptr
=
tl
.
advance
(
bias_ptr
,
(
0
,
n_full_blocks
*
BLOCK_N
))
if
RETURN_ENCODED_SOFTMAX
:
encoded_softmax_block_ptr
=
tl
.
advance
(
encoded_softmax_block_ptr
,
(
0
,
n_full_blocks
)
)
acc
,
l_i
,
m_i
=
_attn_fwd_inner
(
acc
,
l_i
,
m_i
,
q
,
K_block_ptr
,
V_block_ptr
,
start_m
,
seqlen_k
,
dropout_p
,
philox_seed
,
batch_philox_offset
,
encoded_softmax_block_ptr
,
block_min
,
block_max
,
offs_n_causal
,
masked_blocks
,
n_extra_tokens
,
bias_ptr
,
IS_CAUSAL
,
BLOCK_M
,
BLOCK_DMODEL
,
BLOCK_N
,
offs_m
,
offs_n
,
# _, MASK_STEPS, ...
PRE_LOAD_V
,
True
,
ENABLE_DROPOUT
,
RETURN_ENCODED_SOFTMAX
,
PADDED_HEAD
,
)
# epilogue
acc
=
acc
/
l_i
[:,
None
]
if
ENABLE_DROPOUT
:
acc
=
acc
/
(
1
-
dropout_p
)
# If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
# then we have one block with a row of all NaNs which come from computing
# softmax over a row of all -infs (-inf - inf = NaN). We check for that here
# and store 0s where there are NaNs as these rows should've been zeroed out.
end_m_idx
=
(
start_m
+
1
)
*
BLOCK_M
start_m_idx
=
start_m
*
BLOCK_M
causal_start_idx
=
seqlen_q
-
seqlen_k
acc
=
acc
.
to
(
Out
.
type
.
element_ty
)
if
IS_CAUSAL
:
# noqa: SIM102
if
causal_start_idx
>
start_m_idx
and
causal_start_idx
<
end_m_idx
:
out_mask_boundary
=
tl
.
full
(
(
BLOCK_DMODEL
,),
causal_start_idx
,
dtype
=
tl
.
int32
)
mask_m_offsets
=
start_m_idx
+
tl
.
arange
(
0
,
BLOCK_M
)
out_ptrs_mask
=
mask_m_offsets
[:,
None
]
>=
out_mask_boundary
[
None
,
:]
z
=
0.0
acc
=
tl
.
where
(
out_ptrs_mask
,
acc
,
z
.
to
(
acc
.
type
.
element_ty
))
# write back LSE
# l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last
# few rows. This is only true for the last M block. For others,
# overflow_size will be -ve
# overflow_size = end_m_idx - seqlen_q
# if overflow_size > 0:
# boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
# # This is a > check because mask being 0 blocks the store.
# l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
# tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
# else:
# tl.store(l_ptrs, m_i + tl.math.log2(l_i))
# write back O
o_offset
=
off_z
*
stride_oz
+
cu_seqlens_q_start
*
stride_om
+
off_h_q
*
stride_oh
O_block_ptr
=
tl
.
make_block_ptr
(
base
=
Out
+
o_offset
,
shape
=
(
seqlen_q
,
ACTUAL_BLOCK_DMODEL
),
strides
=
(
stride_om
,
stride_on
),
offsets
=
(
start_m
*
BLOCK_M
,
0
),
block_shape
=
(
BLOCK_M
,
BLOCK_DMODEL
),
order
=
(
1
,
0
),
)
# Need boundary check on this to make sure the padding from the
# Q and KV tensors in both dims are not part of what we store back.
# TODO: Do the boundary check optionally.
tl
.
store
(
O_block_ptr
,
acc
,
boundary_check
=
(
0
,
1
))
def
check_args
(
q
,
k
,
v
,
o
,
varlen
=
True
,
max_seqlens
=
None
,
cu_seqlens_q
=
None
,
cu_seqlens_k
=
None
,
):
assert
q
.
dim
()
==
k
.
dim
()
and
q
.
dim
()
==
v
.
dim
()
if
varlen
:
assert
q
.
dim
()
==
3
total_q
,
nheads_q
,
head_size
=
q
.
shape
total_k
,
nheads_k
,
_
=
k
.
shape
assert
cu_seqlens_q
is
not
None
assert
cu_seqlens_k
is
not
None
assert
len
(
cu_seqlens_q
)
==
len
(
cu_seqlens_k
)
else
:
assert
q
.
dim
()
==
4
batch
,
nheads_q
,
seqlen_q
,
head_size
=
q
.
shape
_
,
nheads_k
,
seqlen_k
,
_
=
k
.
shape
assert
max_seqlens
>
0
assert
k
.
shape
==
v
.
shape
assert
q
.
shape
[
-
1
]
==
k
.
shape
[
-
1
]
and
q
.
shape
[
-
1
]
==
v
.
shape
[
-
1
]
# TODO: Change assert if we support qkl f8 and v f16
assert
q
.
dtype
==
k
.
dtype
and
q
.
dtype
==
v
.
dtype
# TODO: Fix assert to check head size <=256 once supported
assert
head_size
<=
128
assert
o
.
shape
==
q
.
shape
assert
(
nheads_q
%
nheads_k
)
==
0
class
_attention
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
q
,
k
,
v
,
o
,
cu_seqlens_q
,
cu_seqlens_k
,
max_seqlens_q
,
max_seqlens_k
,
causal
=
False
,
sm_scale
=
1.0
,
bias
=
None
,
):
if
o
is
None
:
o
=
torch
.
empty_like
(
q
,
dtype
=
v
.
dtype
)
check_args
(
q
,
k
,
v
,
o
,
varlen
=
True
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_k
=
cu_seqlens_k
,
)
if
True
:
# varlen
total_q
,
nheads_q
,
head_size
=
q
.
shape
total_k
,
nheads_k
,
_
=
k
.
shape
batch
=
len
(
cu_seqlens_q
)
-
1
q_strides
=
(
0
,
q
.
stride
(
1
),
q
.
stride
(
0
),
q
.
stride
(
2
))
k_strides
=
(
0
,
k
.
stride
(
1
),
k
.
stride
(
0
),
k
.
stride
(
2
))
v_strides
=
(
0
,
v
.
stride
(
1
),
v
.
stride
(
0
),
v
.
stride
(
2
))
o_strides
=
(
0
,
o
.
stride
(
1
),
o
.
stride
(
0
),
o
.
stride
(
2
))
else
:
batch
,
seqlen_q
,
nheads_q
,
head_size
=
q
.
shape
_
,
seqlen_k
,
nheads_k
,
_
=
k
.
shape
q_strides
=
(
q
.
stride
(
0
),
q
.
stride
(
2
),
q
.
stride
(
1
),
q
.
stride
(
3
))
k_strides
=
(
k
.
stride
(
0
),
k
.
stride
(
2
),
k
.
stride
(
1
),
k
.
stride
(
3
))
v_strides
=
(
v
.
stride
(
0
),
v
.
stride
(
2
),
v
.
stride
(
1
),
v
.
stride
(
3
))
o_strides
=
(
o
.
stride
(
0
),
o
.
stride
(
2
),
o
.
stride
(
1
),
o
.
stride
(
3
))
# Get closest power of 2 over or equal to 32.
padded_d_model
=
1
<<
(
head_size
-
1
).
bit_length
()
padded_d_model
=
max
(
padded_d_model
,
16
)
grid
=
lambda
META
:
(
triton
.
cdiv
(
max_seqlens_q
,
META
[
"BLOCK_M"
]),
nheads_q
,
batch
,
)
encoded_softmax
=
None
# Seed the RNG so we get reproducible results for testing.
philox_seed
=
0x1BF52
philox_offset
=
0x1D4B42
if
bias
is
not
None
:
bias_strides
=
(
bias
.
stride
(
0
),
bias
.
stride
(
1
),
bias
.
stride
(
2
),
bias
.
stride
(
3
),
)
else
:
bias_strides
=
(
0
,
0
,
0
,
0
)
attn_fwd
[
grid
](
q
,
k
,
v
,
bias
,
sm_scale
,
None
,
o
,
*
q_strides
,
*
k_strides
,
*
v_strides
,
*
o_strides
,
*
bias_strides
,
cu_seqlens_q
,
cu_seqlens_k
,
dropout_p
=
0.0
,
philox_seed
=
philox_seed
,
philox_offset_base
=
philox_offset
,
encoded_softmax
=
encoded_softmax
,
HQ
=
nheads_q
,
HK
=
nheads_k
,
ACTUAL_BLOCK_DMODEL
=
head_size
,
MAX_SEQLENS_Q
=
max_seqlens_q
,
MAX_SEQLENS_K
=
max_seqlens_k
,
IS_CAUSAL
=
causal
,
VARLEN
=
True
,
BLOCK_DMODEL
=
padded_d_model
,
BIAS_TYPE
=
0
if
bias
is
None
else
1
,
ENABLE_DROPOUT
=
False
,
RETURN_ENCODED_SOFTMAX
=
False
,
)
ctx
.
grid
=
grid
ctx
.
sm_scale
=
sm_scale
ctx
.
BLOCK_DMODEL
=
head_size
ctx
.
causal
=
causal
ctx
.
dropout_p
=
0.0
ctx
.
philox_seed
=
philox_seed
ctx
.
philox_offset
=
philox_offset
ctx
.
encoded_softmax
=
encoded_softmax
ctx
.
return_encoded_softmax
=
False
return
o
,
encoded_softmax
triton_attention
=
_attention
.
apply