Commit efd602c8 authored by xuxzh1's avatar xuxzh1 🎱
Browse files

last

parent f1b779fc
...@@ -6,14 +6,14 @@ colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_p ...@@ -6,14 +6,14 @@ colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_p
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13" einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13" filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.3.1 ; python_version >= "3.9" and python_version < "3.13" fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13" googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13" grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13" grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.62.2 ; python_version >= "3.9" and python_version < "3.13" grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13" hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13" huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
...@@ -32,17 +32,17 @@ prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" ...@@ -32,17 +32,17 @@ prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.4.28 ; python_version >= "3.9" and python_version < "3.13" regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13" requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.0 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13" setuptools==70.0.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.40.1 ; python_version >= "3.9" and python_version < "3.13" transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13" urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13" wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
...@@ -6,14 +6,14 @@ colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_p ...@@ -6,14 +6,14 @@ colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_p
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13" einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13" filelock==3.14.0 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.3.1 ; python_version >= "3.9" and python_version < "3.13" fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13" googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13" grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13" grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.62.2 ; python_version >= "3.9" and python_version < "3.13" grpcio==1.64.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13" hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13" huggingface-hub==0.23.1 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
...@@ -32,17 +32,17 @@ prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" ...@@ -32,17 +32,17 @@ prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.4.28 ; python_version >= "3.9" and python_version < "3.13" regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13" requests==2.32.2 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.0 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13" setuptools==70.0.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.40.1 ; python_version >= "3.9" and python_version < "3.13" transformers==4.41.1 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.12.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13" urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13" wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
...@@ -29,6 +29,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters): ...@@ -29,6 +29,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request( return generate_pb2.Request(
id=0, id=0,
inputs="Test", inputs="Test",
input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
prefill_logprobs=True, prefill_logprobs=True,
truncate=100, truncate=100,
parameters=default_pb_parameters, parameters=default_pb_parameters,
......
...@@ -25,6 +25,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters): ...@@ -25,6 +25,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request( return generate_pb2.Request(
id=0, id=0,
inputs="Test", inputs="Test",
input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
prefill_logprobs=True, prefill_logprobs=True,
truncate=100, truncate=100,
parameters=default_pb_parameters, parameters=default_pb_parameters,
......
...@@ -17,7 +17,12 @@ def get_test_model(): ...@@ -17,7 +17,12 @@ def get_test_model():
tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b") tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b")
model = TestModel( model = TestModel(
torch.nn.Linear(1, 1), tokenizer, False, torch.float32, torch.device("cpu") "test_model_id",
torch.nn.Linear(1, 1),
tokenizer,
False,
torch.float32,
torch.device("cpu"),
) )
return model return model
......
...@@ -15,6 +15,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters): ...@@ -15,6 +15,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request( return generate_pb2.Request(
id=0, id=0,
inputs="def", inputs="def",
input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="def")]),
prefill_logprobs=True, prefill_logprobs=True,
truncate=100, truncate=100,
parameters=default_pb_parameters, parameters=default_pb_parameters,
...@@ -32,6 +33,13 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters): ...@@ -32,6 +33,13 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request( return generate_pb2.Request(
id=0, id=0,
inputs="<fim-prefix>def<fim-suffix>world<fim-middle>", inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
input_chunks=generate_pb2.Input(
chunks=[
generate_pb2.InputChunk(
text="<fim-prefix>def<fim-suffix>world<fim-middle>"
)
]
),
prefill_logprobs=True, prefill_logprobs=True,
truncate=100, truncate=100,
parameters=default_pb_parameters, parameters=default_pb_parameters,
......
...@@ -28,6 +28,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters): ...@@ -28,6 +28,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
return generate_pb2.Request( return generate_pb2.Request(
id=0, id=0,
inputs="Test", inputs="Test",
input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text="Test")]),
prefill_logprobs=True, prefill_logprobs=True,
truncate=100, truncate=100,
parameters=default_pb_parameters, parameters=default_pb_parameters,
......
import torch import torch
from text_generation_server.utils.layers import ( from text_generation_server.layers import (
TensorParallelEmbedding, TensorParallelEmbedding,
) )
......
import pytest
import torch
from text_generation_server.utils.weights import Weights
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.layers.exl2 import Exl2Weight
from text_generation_server.layers.marlin import MarlinWeight
from types import SimpleNamespace
from typing import List, Optional, Dict, Union
from pathlib import Path
dummy_file_system = {
"test_weights": {
"layer.0.weight": torch.tensor(
[
[1, 2],
[3, 4],
],
dtype=torch.float32,
),
},
"test_weights_2": {
"layer.1337.weight": torch.tensor(
[
[1, 2, 3, 4],
[5, 6, 7, 8],
],
dtype=torch.float32,
),
},
"test_get_weights_col_packed": {
"weight.weight": torch.tensor(
[
[1, 2],
[3, 4],
[5, 6],
[7, 8],
],
dtype=torch.float32,
),
},
"test_get_multi_weights_col": {
"weight.weight": torch.tensor(
[
[1, 2],
[3, 4],
[5, 6],
[7, 8],
],
dtype=torch.float32,
),
"weight.weight": torch.tensor(
[
[1, 2],
[3, 4],
[5, 6],
[7, 8],
],
dtype=torch.float32,
),
},
"test_get_multi_weights_row": {
"weight.weight": torch.tensor(
[
[1, 2],
[3, 4],
[5, 6],
[7, 8],
],
dtype=torch.float32,
),
},
"test_get_weights_col_gptq": {
"weight.qweight": torch.tensor(
[
[1, 2],
[3, 4],
[5, 6],
[7, 8],
],
dtype=torch.float32,
),
"weight.g_idx": torch.tensor([0, 1, 0, 1], dtype=torch.int32),
"weight.qzeros": torch.tensor(
[
[0, 1],
[1, 0],
],
dtype=torch.int32,
),
"weight.scales": torch.tensor(
[
[100.0, 100.0],
[100.0, 100.0],
],
dtype=torch.float16,
),
"gptq_bits": torch.tensor([8], dtype=torch.float32),
"gptq_groupsize": torch.tensor([2], dtype=torch.float32),
},
"test_get_weights_col_marlin": {
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
"weight.s": torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
},
"test_get_multi_weights_row_gptq": {
"weight.qweight": torch.tensor(
[
[1, 2],
[3, 4],
[5, 6],
[7, 8],
],
dtype=torch.int32,
),
"weight.g_idx": torch.tensor([0, 1, 0, 1], dtype=torch.int32),
"weight.qzeros": torch.tensor(
[
[0, 1],
[1, 0],
],
dtype=torch.int32,
),
"weight.scales": torch.tensor(
[
[100.0, 100.0],
[100.0, 100.0],
],
dtype=torch.float16,
),
"gptq_bits": torch.tensor([8], dtype=torch.float32),
"gptq_groupsize": torch.tensor([2], dtype=torch.float32),
},
"test_get_multi_weights_col_gptq": {
"weight.qweight": torch.tensor(
[
[1, 2],
[3, 4],
[5, 6],
[7, 8],
],
dtype=torch.int32,
),
"weight.g_idx": torch.tensor([0, 1, 0, 1], dtype=torch.int32),
"weight.qzeros": torch.tensor(
[
[0, 1],
[1, 0],
],
dtype=torch.int32,
),
"weight.scales": torch.tensor(
[
[100.0, 100.0],
[100.0, 100.0],
],
dtype=torch.float16,
),
"gptq_bits": torch.tensor([8], dtype=torch.float32),
"gptq_groupsize": torch.tensor([2], dtype=torch.float32),
},
"test_get_weights_col_packed_gptq": {
"weight.qweight": torch.tensor(
[
[1, 2],
[3, 4],
[5, 6],
[7, 8],
],
dtype=torch.int32,
),
"weight.g_idx": torch.tensor([0, 1, 0, 1], dtype=torch.int32),
"weight.qzeros": torch.tensor(
[
[0, 1],
[1, 0],
],
dtype=torch.int32,
),
"weight.scales": torch.tensor(
[
[100.0, 100.0],
[100.0, 100.0],
],
dtype=torch.float16,
),
"gptq_bits": torch.tensor([8], dtype=torch.float32),
"gptq_groupsize": torch.tensor([2], dtype=torch.float32),
},
"test_get_weights_col_packed_exl2": {
"weight.q_weight": torch.tensor(
[
[1, 2],
[3, 4],
[5, 6],
[7, 8],
],
dtype=torch.int32,
),
"weight.q_scale": torch.tensor([8], dtype=torch.int32),
"weight.q_invperm": torch.tensor([1, 0, 3, 2], dtype=torch.int32),
"weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
"weight.q_groups": torch.tensor([4], dtype=torch.int16),
},
"test_get_multi_weights_row_exl2": {
"weight.q_weight": torch.tensor(
[
[1, 2],
[3, 4],
[5, 6],
[7, 8],
],
dtype=torch.int32,
),
"weight.q_scale": torch.tensor([8], dtype=torch.int32),
"weight.q_invperm": torch.tensor([1, 0, 3, 2], dtype=torch.int32),
"weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
"weight.q_groups": torch.tensor([4], dtype=torch.int16),
},
"test_get_multi_weights_col_exl2": {
"weight.q_weight": torch.tensor(
[
[1, 2],
[3, 4],
[5, 6],
[7, 8],
],
dtype=torch.int32,
),
"weight.q_scale": torch.tensor([8], dtype=torch.int32),
"weight.q_invperm": torch.tensor([1, 0, 3, 2], dtype=torch.int32),
"weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
"weight.q_groups": torch.tensor([4], dtype=torch.int16),
},
"test_get_weights_col_exl2": {
"weight.q_weight": torch.tensor(
[
[1, 2],
[3, 4],
[5, 6],
[7, 8],
],
dtype=torch.int32,
),
"weight.q_scale": torch.tensor([8], dtype=torch.int32),
"weight.q_invperm": torch.tensor([1, 0, 3, 2], dtype=torch.int32),
"weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
"weight.q_groups": torch.tensor([4], dtype=torch.int16),
},
"test_get_multi_weights_row_marlin": {
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
"weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
},
"test_get_multi_weights_col_marlin": {
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
"weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
},
"test_get_weights_col_packed_marlin": {
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
"weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
},
}
class MockSlice:
def __init__(self, tensor):
self.tensor = tensor
def get_shape(self):
return self.tensor.shape
def __getitem__(self, idx):
return self.tensor[idx]
def mock_get_slice(tensor_name, filename):
tensor = dummy_file_system[filename][tensor_name]
return MockSlice(tensor)
def mock_handle(filename, device, dtype):
return SimpleNamespace(
get_slice=lambda tensor_name: mock_get_slice(tensor_name, filename)
)
class MockSafeOpen:
def __init__(self, filename, framework, dummy_fs):
self.filename = filename
self.framework = framework
self.dummy_fs = dummy_fs
def keys(self):
return list(self.dummy_fs[self.filename].keys())
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
pass
class MockWeights(Weights):
def __init__(
self,
filenames: List[Union[Path, str]],
device,
dtype,
process_group,
dummy_fs,
aliases: Optional[Dict[str, List[str]]] = None,
prefix: Optional[str] = None,
):
routing = {}
self.dummy_fs = dummy_fs
for filename in filenames:
with MockSafeOpen(filename, framework="pytorch", dummy_fs=dummy_fs) as f:
for k in f.keys():
if k in routing:
raise RuntimeError(
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
)
routing[k] = filename
if aliases is None:
aliases = {}
self.aliases = aliases
self.routing = routing
self.device = device
self.dtype = dtype
self.process_group = process_group
self.prefix = prefix
self._handles = {}
def _get_handle(self, filename: Union[Path, str]):
if filename in self._handles:
return self._handles[filename]
else:
handle = mock_handle(filename, self.device, self.dtype)
self._handles[filename] = handle
return handle
def get_shape(self, tensor_name: str):
filename, _ = self.get_filename(tensor_name)
handle = self._get_handle(filename)
return handle.get_slice(tensor_name).get_shape()
def get_tensor(self, tensor_name: str):
filename, _ = self.get_filename(tensor_name)
handle = self._get_handle(filename)
return handle.get_slice(tensor_name).tensor
dummy_process_group = SimpleNamespace(rank=lambda: 0, size=lambda: 1)
def test_weights():
weights = MockWeights(
[
"test_weights",
"test_weights_2",
],
device="cpu",
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
)
assert weights.get_shape("layer.0.weight") == (2, 2)
assert weights.get_tensor("layer.1337.weight").shape == (2, 4)
def test_get_tensor():
weights = MockWeights(
[
"test_weights",
"test_weights_2",
],
device="cpu",
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
)
assert torch.allclose(
weights.get_tensor("layer.0.weight"),
torch.tensor(
[
[1, 2],
[3, 4],
],
dtype=torch.float32,
),
)
assert torch.allclose(
weights.get_tensor("layer.1337.weight"),
torch.tensor(
[
[1, 2, 3, 4],
[5, 6, 7, 8],
],
dtype=torch.float32,
),
)
def test_get_weights_col_packed():
weights = MockWeights(
[
"test_get_weights_col_packed",
],
device="cpu",
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
)
prefix = "weight"
quantize = None
block_sizes = 1
w = weights.get_weights_col_packed(
prefix=prefix,
quantize=quantize,
block_sizes=block_sizes,
)
assert torch.allclose(
w,
torch.tensor(
[
[1, 2],
[3, 4],
[5, 6],
[7, 8],
],
dtype=torch.float32,
),
)
def test_get_weights_col_packed_block_size():
weights = MockWeights(
[
"test_get_weights_col_packed",
],
device="cpu",
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
)
prefix = "weight"
quantize = None
block_sizes = 2
w = weights.get_weights_col_packed(
prefix=prefix,
quantize=quantize,
block_sizes=block_sizes,
)
assert torch.allclose(
w,
torch.tensor(
[
[1, 2],
[3, 4],
[5, 6],
[7, 8],
],
dtype=torch.float32,
),
)
def test_get_weights_col_packed_block_size_arr():
weights = MockWeights(
[
"test_get_weights_col_packed",
],
device="cpu",
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
)
prefix = "weight"
quantize = None
block_sizes = [1, 1]
w = weights.get_weights_col_packed(
prefix=prefix,
quantize=quantize,
block_sizes=block_sizes,
)
assert torch.allclose(
w,
torch.tensor(
[
[1, 2],
[3, 4],
[5, 6],
[7, 8],
],
dtype=torch.float32,
),
)
def test_get_multi_weights_col():
weights = MockWeights(
[
"test_get_multi_weights_col",
],
device="cpu",
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
)
prefixes = ["weight", "weight"]
quantize = None
w = weights.get_multi_weights_col(
prefixes=prefixes,
quantize=quantize,
dim=0,
)
assert torch.allclose(
w,
torch.tensor(
[
[1, 2],
[3, 4],
[5, 6],
[7, 8],
[1, 2],
[3, 4],
[5, 6],
[7, 8],
],
dtype=torch.float32,
),
)
def test_get_multi_weights_row():
weights = MockWeights(
[
"test_get_multi_weights_row",
],
device="cpu",
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
)
prefix = "weight"
quantize = None
w = weights.get_multi_weights_row(
prefix=prefix,
quantize=quantize,
)
assert torch.allclose(
w,
torch.tensor(
[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]],
dtype=torch.float32,
),
)
# test_get_weights_col
def test_get_weights_col_awq():
weights = MockWeights(
[
"test_get_weights_col_gptq",
],
device="cpu",
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
)
prefix = "weight"
quantize = "awq"
w = weights.get_weights_col(
prefix=prefix,
quantize=quantize,
)
expected_weight = GPTQWeight(
qweight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]),
qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
scales=torch.tensor(
[[100.0, 100.0], [100.0, 100.0]],
dtype=torch.float16,
),
g_idx=None,
bits=8.0,
groupsize=2.0,
use_exllama=False,
)
assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
def test_get_weights_col_gtpq():
weights = MockWeights(
[
"test_get_weights_col_gptq",
],
device="cpu",
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
)
prefix = "weight"
quantize = "gptq"
w = weights.get_weights_col(
prefix=prefix,
quantize=quantize,
)
expected_weight = GPTQWeight(
qweight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]),
qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16),
g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),
bits=8.0,
groupsize=2.0,
use_exllama=False,
)
assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
def test_get_weights_col_exl2():
weights = MockWeights(
[
"test_get_weights_col_exl2",
],
device="cpu",
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
)
prefix = "weight"
quantize = "exl2"
w = weights.get_weights_col(
prefix=prefix,
quantize=quantize,
)
scaled_scale_max = 0.3906 * 256
expected_weight = Exl2Weight(
q_weight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
q_scale=torch.tensor([8], dtype=torch.int32),
q_invperm=torch.tensor([1, 0, 3, 2], dtype=torch.int16),
q_scale_max=torch.tensor([scaled_scale_max], dtype=torch.float16),
q_groups=torch.tensor([4], dtype=torch.int16),
)
assert torch.allclose(w.q_weight, expected_weight.q_weight), "q_weight mismatch"
assert torch.allclose(w.q_scale, expected_weight.q_scale), "q_scale mismatch"
assert torch.allclose(w.q_invperm, expected_weight.q_invperm), "q_invperm mismatch"
assert torch.allclose(
w.q_scale_max, expected_weight.q_scale_max
), "q_scale_max mismatch"
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
def test_get_weights_col_marlin():
weights = MockWeights(
[
"test_get_weights_col_marlin",
],
device="cpu",
dtype=torch.float16,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
)
prefix = "weight"
quantize = "marlin"
w = weights.get_weights_col(
prefix=prefix,
quantize=quantize,
)
expected_weight = MarlinWeight(
B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
)
assert torch.allclose(w.B, expected_weight.B), "B mismatch"
assert torch.allclose(w.s, expected_weight.s), "s mismatch"
# test_get_weights_col_packed
def test_get_weights_col_packed_awq():
weights = MockWeights(
[
"test_get_weights_col_packed_gptq",
],
device="cpu",
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
)
prefix = "weight"
quantize = "awq"
block_sizes = 1
w = weights.get_weights_col_packed(
prefix=prefix,
quantize=quantize,
block_sizes=block_sizes,
)
expected_weight = GPTQWeight(
qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16),
g_idx=None,
bits=8.0,
groupsize=2.0,
use_exllama=False,
)
assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
@pytest.mark.skip(reason="Review expected functionality")
def test_get_weights_col_packed_exl2():
weights = MockWeights(
[
"test_get_weights_col_packed_exl2",
],
device="cpu",
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
)
prefix = "weight"
quantize = "exl2"
block_sizes = 1
w = weights.get_weights_col_packed(
prefix=prefix,
quantize=quantize,
block_sizes=block_sizes,
)
scaled_scale_max = 0.3906 * 256
expected_weight = Exl2Weight(
q_weight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
q_scale=torch.tensor([8], dtype=torch.int32),
q_invperm=torch.tensor([1], dtype=torch.int16),
q_scale_max=torch.tensor([scaled_scale_max], dtype=torch.float16),
q_groups=torch.tensor([4], dtype=torch.int16),
)
assert torch.allclose(w.q_weight, expected_weight.q_weight), "q_weight mismatch"
assert torch.allclose(w.q_scale, expected_weight.q_scale), "q_scale mismatch"
assert torch.allclose(w.q_invperm, expected_weight.q_invperm), "q_invperm mismatch"
assert torch.allclose(
w.q_scale_max, expected_weight.q_scale_max
), "q_scale_max mismatch"
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
def test_get_weights_col_packed_gptq():
weights = MockWeights(
[
"test_get_weights_col_packed_gptq",
],
device="cpu",
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
)
prefixes = ["weight"]
quantize = "gptq"
w = weights.get_multi_weights_col(
prefixes=prefixes,
quantize=quantize,
dim=0,
)
expected_weight = GPTQWeight(
qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16),
g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),
bits=8.0,
groupsize=2.0,
use_exllama=False,
)
assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
def test_get_weights_col_packed_marlin():
weights = MockWeights(
[
"test_get_weights_col_packed_marlin",
],
device="cpu",
dtype=torch.float16,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
)
prefix = "weight"
quantize = "marlin"
w = weights.get_multi_weights_col(
prefixes=[prefix],
quantize=quantize,
dim=0,
)
expected_weight = MarlinWeight(
B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
)
print(expected_weight)
assert torch.allclose(w.B, expected_weight.B), "B mismatch"
assert torch.allclose(w.s, expected_weight.s), "s mismatch"
# test_get_multi_weights_col
def test_get_multi_weights_col_awq():
weights = MockWeights(
[
"test_get_multi_weights_col_gptq",
],
device="cpu",
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
)
prefixes = ["weight"]
quantize = "awq"
w = weights.get_multi_weights_col(
prefixes=prefixes,
quantize=quantize,
dim=0,
)
expected_weight = GPTQWeight(
qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16),
g_idx=None,
bits=8.0,
groupsize=2.0,
use_exllama=False,
)
assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
def test_get_multi_weights_col_exl2():
weights = MockWeights(
[
"test_get_multi_weights_col_exl2",
],
device="cpu",
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
)
prefix = "weight"
quantize = "exl2"
try:
w = weights.get_multi_weights_col(
prefixes=[prefix],
quantize=quantize,
dim=0,
)
except ValueError as e:
assert e.args[0] == "get_multi_weights_col is not supported for exl2"
def test_get_multi_weights_col_gptq():
weights = MockWeights(
[
"test_get_multi_weights_col_gptq",
],
device="cpu",
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
)
prefixes = ["weight"]
quantize = "gptq"
w = weights.get_multi_weights_col(
prefixes=prefixes,
quantize=quantize,
dim=0,
)
expected_weight = GPTQWeight(
qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16),
g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),
bits=8.0,
groupsize=2.0,
use_exllama=False,
)
assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
def test_get_multi_weights_col_marlin():
weights = MockWeights(
[
"test_get_multi_weights_col_marlin",
],
device="cpu",
dtype=torch.float16,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
)
prefix = "weight"
quantize = "marlin"
w = weights.get_multi_weights_col(
prefixes=[prefix],
quantize=quantize,
dim=0,
)
expected_weight = MarlinWeight(
B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
)
assert torch.allclose(w.B, expected_weight.B), "B mismatch"
assert torch.allclose(w.s, expected_weight.s), "s mismatch"
# test_get_multi_weights_row
def test_get_multi_weights_row_awq():
weights = MockWeights(
[
"test_get_multi_weights_row_gptq",
],
device="cpu",
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
)
prefix = "weight"
quantize = "awq"
w = weights.get_multi_weights_row(
prefix=prefix,
quantize=quantize,
)
expected_weight = GPTQWeight(
qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16),
g_idx=None,
bits=8.0,
groupsize=2.0,
use_exllama=False,
)
assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
assert w.g_idx == expected_weight.g_idx, "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
def test_get_multi_weights_row_exl2():
weights = MockWeights(
[
"test_get_multi_weights_row_exl2",
],
device="cpu",
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
)
prefix = "weight"
quantize = "exl2"
w = weights.get_multi_weights_row(
prefix=prefix,
quantize=quantize,
)
print(w)
scaled_scale_max = 0.3906 * 256
expected_weight = Exl2Weight(
q_weight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
q_scale=torch.tensor([8], dtype=torch.int32),
q_invperm=torch.tensor([1, 0, 3, 2], dtype=torch.int16),
q_scale_max=torch.tensor([scaled_scale_max], dtype=torch.float16),
q_groups=torch.tensor([4], dtype=torch.int16),
)
assert torch.allclose(w.q_weight, expected_weight.q_weight), "q_weight mismatch"
assert torch.allclose(w.q_scale, expected_weight.q_scale), "q_scale mismatch"
assert torch.allclose(w.q_invperm, expected_weight.q_invperm), "q_invperm mismatch"
assert torch.allclose(
w.q_scale_max, expected_weight.q_scale_max
), "q_scale_max mismatch"
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
def test_get_multi_weights_row_gptq():
weights = MockWeights(
[
"test_get_multi_weights_row_gptq",
],
device="cpu",
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
)
prefix = "weight"
quantize = "gptq"
w = weights.get_multi_weights_row(
prefix=prefix,
quantize=quantize,
)
expected_weight = GPTQWeight(
qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),
qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16),
g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),
bits=8.0,
groupsize=2.0,
use_exllama=False,
)
assert torch.allclose(w.qweight, expected_weight.qweight), "qweight mismatch"
assert torch.allclose(w.qzeros, expected_weight.qzeros), "qzeros mismatch"
assert torch.allclose(w.scales, expected_weight.scales), "scales mismatch"
assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch"
assert w.bits == expected_weight.bits, "bits mismatch"
assert w.groupsize == expected_weight.groupsize, "groupsize mismatch"
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
def test_get_multi_weights_row_marlin():
weights = MockWeights(
[
"test_get_multi_weights_row_marlin",
],
device="cpu",
dtype=torch.float16,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
)
prefix = "weight"
quantize = "marlin"
w = weights.get_multi_weights_row(
prefix=prefix,
quantize=quantize,
)
expected_weight = MarlinWeight(
B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
)
assert torch.allclose(w.B, expected_weight.B), "B mismatch"
assert torch.allclose(w.s, expected_weight.s), "s mismatch"
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/__init__.py
# License: Apache License Version 2.0, January 2004
from text_generation_server.adapters.weights import (
AdapterBatchData,
AdapterBatchMetadata,
)
__all__ = [
"AdapterBatchData",
"AdapterBatchMetadata",
]
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/config.py
# License: Apache License Version 2.0, January 2004
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
import torch
from text_generation_server.adapters.weights import AdapterWeights
if TYPE_CHECKING:
from text_generation_server.models.model import Model
@dataclass
class ModuleMap:
module_name: str
module_weights: Dict[str, Tuple[torch.Tensor, str]]
@dataclass
class AdapterConfig(ABC):
base_model_name_or_path: str
@abstractmethod
def map_weights_for_model(
self,
adapter_weights: Dict[int, AdapterWeights],
weight_names: Tuple[str],
) -> Tuple[ModuleMap, Set[str]]:
pass
@abstractmethod
def load_batched_adapter_weights(
self,
model: "Model",
module_map: ModuleMap,
layer_type: str,
unused_weight_names: Set[str],
dynamic: bool,
) -> Optional[AdapterWeights]:
pass
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/lora.py
# License: Apache License Version 2.0, January 2004
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type, Union
import torch
from peft import LoraConfig as _LoraConfig
from torch.distributed import ProcessGroup
from text_generation_server.adapters.config import AdapterConfig, ModuleMap
from text_generation_server.adapters.weights import (
AdapterBatchMetadata,
AdapterWeights,
BatchAdapterWeights,
)
from text_generation_server.utils.sgmv import (
BGMV_MAX_RANK,
MAX_RANK_CUSTOM,
get_tmp_tensors,
orient_for_rank,
pad_rank,
use_cutlass_shrink,
)
if TYPE_CHECKING:
from text_generation_server.models.model import Model
def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
block_size = size // world_size
start = offset + rank * block_size
stop = offset + (rank + 1) * block_size
return start, stop
def shard_on_dim(
t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup
):
world_size = process_group.size()
rank = process_group.rank()
size = t.shape[dim]
start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)
if dim == 0:
tensor = t[start:stop]
elif dim == 1:
tensor = t[:, start:stop]
else:
raise NotImplementedError("Let's make that generic when needed")
return tensor
def shard_lora_weights(
weights_a: List[torch.Tensor],
weights_b: List[torch.Tensor],
split_dim: int,
process_group: ProcessGroup,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
# [hidden_size, r]
weights_a = [
shard_on_dim(w, dim=split_dim, process_group=process_group) for w in weights_a
]
# [r, hidden_size]
weights_b = [shard_on_dim(w, dim=1, process_group=process_group) for w in weights_b]
return weights_a, weights_b
@dataclass
class LoraConfig(AdapterConfig):
r: int
target_modules: Optional[Union[List[str], str]]
fan_in_fan_out: bool
lora_alpha: int
use_rslora: bool
def map_weights_for_model(
self,
adapter_weights: Dict[int, AdapterWeights],
weight_names: Tuple[str],
) -> Tuple[ModuleMap, Set[str]]:
adapter_weight_names = set()
module_map = {}
for weight_name in weight_names:
lora_a_name = f"base_model.model.{weight_name}.lora_A.weight"
lora_b_name = f"base_model.model.{weight_name}.lora_B.weight"
if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights:
continue
module_map[weight_name] = {
"lora_A": (adapter_weights[lora_a_name], lora_a_name),
"lora_B": (adapter_weights[lora_b_name], lora_b_name),
}
adapter_weight_names.add(lora_a_name)
adapter_weight_names.add(lora_b_name)
return module_map, adapter_weight_names
def load_batched_adapter_weights(
self,
model: "Model",
module_map: Dict[str, Dict],
layer_type: str,
unused_weight_names: Set[str],
dynamic: bool,
) -> Optional[AdapterWeights]:
return LoraWeights.load(
self,
model,
module_map,
layer_type,
unused_weight_names,
)
@classmethod
def load(cls, adapter_id: str, api_token: str) -> "LoraConfig":
hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)
return cls(
base_model_name_or_path=hf_config.base_model_name_or_path,
r=hf_config.r,
target_modules=hf_config.target_modules,
fan_in_fan_out=hf_config.fan_in_fan_out,
lora_alpha=hf_config.lora_alpha,
use_rslora=(
hf_config.use_rslora if hasattr(hf_config, "use_rslora") else False
),
)
class LoraWeights(AdapterWeights):
"""LoRA weights for a single adapter merged across all layers."""
def __init__(
self,
weights_a: List[torch.Tensor],
weights_b: List[torch.Tensor],
adapter_config: LoraConfig,
):
self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1
self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1
self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r)
self._is_transposed = False
# [num_layers, hidden_size, r]
weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]
self._weights_a = torch.stack(weights_a)
# [num_layers, r, hidden_size]
self._weights_b = torch.stack(weights_b)
self.adapter_config = adapter_config
@property
def weights_a(self) -> torch.Tensor:
if self._is_transposed:
self._transpose_weights()
return self._weights_a
@property
def weights_b(self) -> torch.Tensor:
if self._is_transposed:
self._transpose_weights()
return self._weights_b
@property
def weights_a_t(self) -> torch.Tensor:
if not self._is_transposed:
self._transpose_weights()
return self._weights_a
@property
def weights_b_t(self) -> torch.Tensor:
if not self._is_transposed:
self._transpose_weights()
return self._weights_b
def _transpose_weights(self):
if self._use_cutlass_shrink:
# If we're not using the cutlass shrink, then both SGMV and BGMV use the same orientation
self._weights_a = self._weights_a.transpose(1, 2).contiguous()
self._weights_b = self._weights_b.transpose(1, 2).contiguous()
self._is_transposed = not self._is_transposed
@classmethod
def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:
return [BatchLoraWeights]
@classmethod
def load(
cls,
config: LoraConfig,
model: "Model",
module_map: Dict[str, Dict],
layer_type: str,
unused_weight_names: Set[str],
) -> Optional[AdapterWeights]:
nlayers = model.get_num_layers_for_type(layer_type)
lora_a_list = [None] * nlayers
lora_b_list = [None] * nlayers
for layer_id in range(nlayers):
key = (layer_id, layer_type)
weight_name, layer = model.target_to_layer[key]
base_weight = layer.base_layer.linear.weight
base_device = base_weight.device
if weight_name not in module_map:
# There is no LoRA weight for this layer type in the adapter
return None
lora_a, lora_a_name = module_map[weight_name]["lora_A"]
lora_a = lora_a.to(base_device, model.dtype)
lora_b, lora_b_name = module_map[weight_name]["lora_B"]
lora_b = lora_b.to(base_device, model.dtype)
scale = get_scaling_factor(
config.lora_alpha,
config.r,
uses_rslora=config.use_rslora,
)
unused_weight_names.discard(lora_a_name)
unused_weight_names.discard(lora_b_name)
# Merge scaling factor into lora_b due to associativity of matrix multiplication:
# (A * B) * C = A * (B * C)
lora_a_list[layer_id] = lora_a.transpose(0, 1)
lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale
# pad lora ranks to be compatible with sgmv
lora_a_list = [
pad_rank(w, dim=1, world_size=model.world_size) for w in lora_a_list
]
lora_b_list = [
pad_rank(w, dim=0, world_size=model.world_size) for w in lora_b_list
]
if lora_a_list:
# update rank if it was padded
padded_rank = lora_a_list[0].size(1)
config.r = padded_rank
return LoraWeights(
*shard_lora_weights(
weights_a=lora_a_list,
weights_b=lora_b_list,
split_dim=0 if model.is_row_parallel(layer_type) else 1,
process_group=model.process_group,
),
config,
)
@dataclass
class RankSegments:
rank: int
lora_a_ptr: torch.Tensor
lora_b_ptr: torch.Tensor
# prefill (sgmv)
tmp_shrink: torch.Tensor
tmp_expand: torch.Tensor
segment_starts: torch.Tensor
segment_ends: torch.Tensor
# decode (bgmv)
indices: torch.Tensor
@dataclass
class BatchLoraWeights(BatchAdapterWeights):
lora_a: Dict[int, torch.Tensor]
lora_b: Dict[int, torch.Tensor]
adapter_index_configs: Dict[int, LoraConfig]
rank_data: Dict[int, RankSegments]
use_sgmv: bool
def has_adapter(self, adapter_index: int) -> bool:
return adapter_index in self.adapter_index_configs
def can_vectorize(self, pg: ProcessGroup) -> bool:
return all(
rank_data.rank // pg.size() <= MAX_RANK_CUSTOM
for rank_data in self.rank_data.values()
)
@classmethod
def key(cls) -> str:
return "lora"
@classmethod
def load(
self,
adapter_weights: Dict[int, AdapterWeights],
meta: AdapterBatchMetadata,
prefill: bool,
prefill_head_indices: Optional[torch.Tensor],
) -> Optional["BatchLoraWeights"]:
adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}
adapter_weights = {
k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)
}
if not adapter_weights:
return None
first_weights = next(iter(adapter_weights.values()))
device = first_weights.weights_a.device
segment_indices = meta.segment_indices
lora_a = {
idx: adapter_weights[idx].weights_a
for idx in segment_indices
if idx in adapter_weights
}
lora_b = {
idx: adapter_weights[idx].weights_b
for idx in segment_indices
if idx in adapter_weights
}
max_rank = max(
(
adapter_weights[idx].lora_a_r
for idx in segment_indices
if idx in adapter_weights
),
default=0,
)
if prefill or max_rank > BGMV_MAX_RANK:
use_sgmv = True
lora_a_ptr = torch.tensor(
[
(
adapter_weights[idx].weights_a.data_ptr()
if idx in adapter_weights
else 0
)
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
lora_b_ptr = torch.tensor(
[
(
adapter_weights[idx].weights_b.data_ptr()
if idx in adapter_weights
else 0
)
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
else:
use_sgmv = False
lora_a_ptr = torch.tensor(
[
(
adapter_weights[idx].weights_a_t.data_ptr()
if idx in adapter_weights
else 0
)
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
lora_b_ptr = torch.tensor(
[
(
adapter_weights[idx].weights_b_t.data_ptr()
if idx in adapter_weights
else 0
)
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
adapter_index_configs = {
idx: adapter_weights[idx].adapter_config
for idx in segment_indices
if idx in adapter_weights
}
adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}
rank_indices = defaultdict(list)
for segment_idx, adapter_idx in enumerate(segment_indices):
if adapter_idx not in adapter_weights:
continue
rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)
if prefill_head_indices is not None:
j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
for head_index in prefill_head_indices:
# j cannot go out of bounds as that would mean there are tokens without corresponding adapters
if head_index < meta.adapter_segments[j]:
prefill_head_segment_ends[-1] += 1
else:
prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
j += 1
rank_data = {}
for rank, indices in rank_indices.items():
tmp_shrink = None
tmp_expand = None
segment_starts = None
segment_ends = None
batch_indices = None
if use_sgmv:
lora_a_ptr_indices = lora_a_ptr[indices]
tmp_shrink, tmp_expand = get_tmp_tensors(
lora_a_ptr_indices.size(0), rank, device
)
segment_starts = meta.adapter_segments[indices]
segment_ends = meta.adapter_segments[[i + 1 for i in indices]]
if prefill_head_indices is not None:
for i, segment_index in enumerate(indices):
segment_starts[i] = prefill_head_segment_starts[segment_index]
segment_ends[i] = prefill_head_segment_ends[segment_index]
else:
rank_indices = set(indices)
batch_indices = [
adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()
]
batch_indices = [
idx if idx in rank_indices else -1 for idx in batch_indices
]
batch_indices = torch.tensor(
batch_indices, dtype=torch.int64, device=device
)
rank_data[rank] = RankSegments(
rank=rank,
tmp_shrink=tmp_shrink,
tmp_expand=tmp_expand,
lora_a_ptr=lora_a_ptr[indices],
lora_b_ptr=lora_b_ptr[indices],
segment_starts=segment_starts,
segment_ends=segment_ends,
indices=batch_indices,
)
return BatchLoraWeights(
lora_a=lora_a,
lora_b=lora_b,
adapter_index_configs=adapter_index_configs,
rank_data=rank_data,
use_sgmv=use_sgmv,
)
def get_scaling_factor(
lora_alpha: int,
r: int,
uses_rslora: bool = False,
) -> float:
"""Computes the scaling factor for the lora weights."""
if uses_rslora:
return lora_alpha / (r**0.5)
return lora_alpha / r
def _convert_lora(v: AdapterWeights) -> AdapterWeights:
if hasattr(v, "lora_weights"):
return v.lora_weights
return v
# Origin: https://github.com/predibase/lorax
# Path: lorax/server/lorax_server/adapters/weights.py
# License: Apache License Version 2.0, January 2004
from abc import ABC, abstractclassmethod
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Type
import torch
@dataclass
class AdapterBatchMetadata:
# [batch_size]
adapter_indices: torch.Tensor
# [num_adapters]
adapter_set: Set[int]
# [num_segments + 1]
adapter_segments: torch.Tensor
# [num_segments]
# maps from segment index to adapter index, i.e.:
# segment_indices[s] == adapter_indices[i]
segment_indices: List[int]
class AdapterWeights(ABC):
@abstractclassmethod
def get_batch_types(cls) -> List[Type["BatchAdapterWeights"]]:
pass
@property
def speculative_tokens(self) -> int:
return 0
class BatchAdapterWeights(ABC):
@abstractclassmethod
def has_adapter(self, adapter_index: int) -> bool:
pass
@abstractclassmethod
def key(cls) -> str:
pass
@abstractclassmethod
def load(
cls,
adapter_weights: Dict[int, AdapterWeights],
meta: "AdapterBatchMetadata",
prefill: bool,
prefill_head_indices: torch.Tensor,
) -> Optional["BatchAdapterWeights"]:
pass
class LayerAdapterWeights:
"""Adapter weights that apply to a particular layer."""
def __init__(self):
self.adapter_weights: Dict[int, AdapterWeights] = {}
def add_adapter(self, adapter_idx: int, weights: AdapterWeights):
self.adapter_weights[adapter_idx] = weights
def remove_adapter(self, adapter_idx: int):
if adapter_idx not in self.adapter_weights:
return
del self.adapter_weights[adapter_idx]
@property
def max_speculative_tokens(self) -> int:
return max(
adapter_weights.speculative_tokens
for adapter_weights in self.adapter_weights.values()
)
def is_empty(self) -> bool:
return len(self.adapter_weights) == 0
def get_data(
self,
meta: AdapterBatchMetadata,
prefill: bool,
prefill_head_indices: Optional[torch.Tensor],
) -> Dict[str, BatchAdapterWeights]:
# bucket adapters by batch class
adapter_batch_types: Dict[
Type[BatchAdapterWeights], Dict[int, AdapterWeights]
] = defaultdict(dict)
for adapter_index, adapter_weights in self.adapter_weights.items():
for batch_type in adapter_weights.get_batch_types():
adapter_batch_types[batch_type][adapter_index] = adapter_weights
batch_data = {}
for batch_type, adapter_weights in adapter_batch_types.items():
batched_weights = batch_type.load(
adapter_weights, meta, prefill, prefill_head_indices
)
if batched_weights is not None:
batch_data[batch_type.key()] = batched_weights
return batch_data
@dataclass
class AdapterBatchData:
meta: AdapterBatchMetadata
# layer type -> adapter type -> batch weight data
data: Dict[str, Dict[str, BatchAdapterWeights]]
prefill: bool
@staticmethod
def from_meta(
meta: AdapterBatchMetadata,
weights: Dict[str, LayerAdapterWeights],
prefill: bool,
prefill_head_indices: Optional[torch.Tensor],
) -> "AdapterBatchData":
data = {}
for k, v in weights.items():
if v.is_empty():
continue
data[k] = v.get_data(
meta, prefill, prefill_head_indices if k == "lm_head" else None
)
return AdapterBatchData(meta=meta, data=data, prefill=prefill)
def ranks(self) -> Set[int]:
# TODO(travis): refactor to be less coupled to lora implementation
ranks = set()
for layer_data in self.data.values():
lora_data = layer_data.get("lora")
if lora_data is None:
continue
for rank_data in lora_data.rank_data.values():
ranks.add(rank_data.rank)
return ranks
def layer_names(self) -> Set[str]:
return set(self.data.keys())
def adapter_keys(self) -> Set[str]:
adapter_keys = set()
for layer_data in self.data.values():
adapter_keys.update(layer_data.keys())
return adapter_keys
@property
def max_rank(self) -> int:
ranks = self.ranks()
return max(ranks) if len(ranks) > 0 else 0
...@@ -19,7 +19,9 @@ class Quantization(str, Enum): ...@@ -19,7 +19,9 @@ class Quantization(str, Enum):
gptq = "gptq" gptq = "gptq"
awq = "awq" awq = "awq"
eetq = "eetq" eetq = "eetq"
exl2 = "exl2"
fp8 = "fp8" fp8 = "fp8"
marlin = "marlin"
class Dtype(str, Enum): class Dtype(str, Enum):
...@@ -40,6 +42,8 @@ def serve( ...@@ -40,6 +42,8 @@ def serve(
logger_level: str = "INFO", logger_level: str = "INFO",
json_output: bool = False, json_output: bool = False,
otlp_endpoint: Optional[str] = None, otlp_endpoint: Optional[str] = None,
otlp_service_name: str = "text-generation-inference.server",
max_input_tokens: Optional[int] = None,
): ):
if sharded: if sharded:
assert ( assert (
...@@ -73,7 +77,19 @@ def serve( ...@@ -73,7 +77,19 @@ def serve(
# Setup OpenTelemetry distributed tracing # Setup OpenTelemetry distributed tracing
if otlp_endpoint is not None: if otlp_endpoint is not None:
setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint) setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)
lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)
# split on comma and strip whitespace
lora_adapter_ids = (
[x.strip() for x in lora_adapter_ids.split(",")] if lora_adapter_ids else []
)
if len(lora_adapter_ids) > 0:
logger.warning(
f"LoRA adapters are enabled. This is an experimental feature and may not work as expected."
)
# Downgrade enum into str for easier management later on # Downgrade enum into str for easier management later on
quantize = None if quantize is None else quantize.value quantize = None if quantize is None else quantize.value
...@@ -89,6 +105,7 @@ def serve( ...@@ -89,6 +105,7 @@ def serve(
) )
server.serve( server.serve(
model_id, model_id,
lora_adapter_ids,
revision, revision,
sharded, sharded,
quantize, quantize,
...@@ -96,6 +113,7 @@ def serve( ...@@ -96,6 +113,7 @@ def serve(
dtype, dtype,
trust_remote_code, trust_remote_code,
uds_path, uds_path,
max_input_tokens,
) )
...@@ -108,6 +126,7 @@ def download_weights( ...@@ -108,6 +126,7 @@ def download_weights(
logger_level: str = "INFO", logger_level: str = "INFO",
json_output: bool = False, json_output: bool = False,
trust_remote_code: bool = False, trust_remote_code: bool = False,
merge_lora: bool = False,
): ):
# Remove default handler # Remove default handler
logger.remove() logger.remove()
...@@ -138,47 +157,53 @@ def download_weights( ...@@ -138,47 +157,53 @@ def download_weights(
) is not None ) is not None
if not is_local_model: if not is_local_model:
try: # TODO: maybe reverse the default value of merge_lora?
adapter_config_filename = hf_hub_download( # currently by default we don't merge the weights with the base model
model_id, revision=revision, filename="adapter_config.json" if merge_lora:
) try:
utils.download_and_unload_peft( adapter_config_filename = hf_hub_download(
model_id, revision, trust_remote_code=trust_remote_code model_id, revision=revision, filename="adapter_config.json"
) )
is_local_model = True utils.download_and_unload_peft(
utils.weight_files(model_id, revision, extension) model_id, revision, trust_remote_code=trust_remote_code
return )
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): is_local_model = True
pass utils.weight_files(model_id, revision, extension)
return
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass
else:
try:
utils.peft.download_peft(
model_id, revision, trust_remote_code=trust_remote_code
)
except Exception:
pass
try: try:
import json import json
medusa_head = hf_hub_download( config = hf_hub_download(
model_id, revision=revision, filename="medusa_lm_head.safetensors"
)
medusa_config = hf_hub_download(
model_id, revision=revision, filename="config.json" model_id, revision=revision, filename="config.json"
) )
with open(medusa_config, "r") as f: with open(config, "r") as f:
config = json.load(f) config = json.load(f)
model_id = config["base_model_name_or_path"] base_model_id = config.get("base_model_name_or_path", None)
revision = "main" if base_model_id and base_model_id != model_id:
try: try:
utils.weight_files(model_id, revision, extension) logger.info(f"Downloading parent model {base_model_id}")
logger.info( download_weights(
f"Files for parent {model_id} are already present on the host. " model_id=base_model_id,
"Skipping download." revision="main",
) extension=extension,
return auto_convert=auto_convert,
# Local files not found logger_level=logger_level,
except ( json_output=json_output,
utils.LocalEntryNotFoundError, trust_remote_code=trust_remote_code,
FileNotFoundError, )
utils.EntryNotFoundError, except Exception:
): pass
pass
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass pass
...@@ -195,31 +220,6 @@ def download_weights( ...@@ -195,31 +220,6 @@ def download_weights(
if not extension == ".safetensors" or not auto_convert: if not extension == ".safetensors" or not auto_convert:
raise e raise e
elif (Path(model_id) / "medusa_lm_head.safetensors").exists():
# Try to load as a local Medusa model
try:
import json
medusa_head = Path(model_id) / "medusa_lm_head.safetensors"
medusa_config = Path(model_id) / "config.json"
with open(medusa_config, "r") as f:
config = json.load(f)
model_id = config["base_model_name_or_path"]
revision = "main"
try:
utils.weight_files(model_id, revision, extension)
logger.info(
f"Files for parent {model_id} are already present on the host. "
"Skipping download."
)
return
# Local files not found
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass
elif (Path(model_id) / "adapter_config.json").exists(): elif (Path(model_id) / "adapter_config.json").exists():
# Try to load as a local PEFT model # Try to load as a local PEFT model
try: try:
...@@ -230,14 +230,43 @@ def download_weights( ...@@ -230,14 +230,43 @@ def download_weights(
return return
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass pass
elif (Path(model_id) / "config.json").exists():
# Try to load as a local Medusa model
try:
import json
config = Path(model_id) / "config.json"
with open(config, "r") as f:
config = json.load(f)
base_model_id = config.get("base_model_name_or_path", None)
if base_model_id:
try:
logger.info(f"Downloading parent model {base_model_id}")
download_weights(
model_id=base_model_id,
revision="main",
extension=extension,
auto_convert=auto_convert,
logger_level=logger_level,
json_output=json_output,
trust_remote_code=trust_remote_code,
)
except Exception:
pass
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass
# Try to see if there are local pytorch weights # Try to see if there are local pytorch weights
try: try:
# Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
local_pt_files = utils.weight_files(model_id, revision, ".bin") try:
local_pt_files = utils.weight_files(model_id, revision, ".bin")
except Exception:
local_pt_files = utils.weight_files(model_id, revision, ".pt")
# No local pytorch weights # No local pytorch weights
except utils.LocalEntryNotFoundError: except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
if extension == ".safetensors": if extension == ".safetensors":
logger.warning( logger.warning(
f"No safetensors weights found for model {model_id} at revision {revision}. " f"No safetensors weights found for model {model_id} at revision {revision}. "
...@@ -312,7 +341,7 @@ def quantize( ...@@ -312,7 +341,7 @@ def quantize(
logger_level=logger_level, logger_level=logger_level,
json_output=json_output, json_output=json_output,
) )
from text_generation_server.utils.gptq.quantize import quantize from text_generation_server.layers.gptq.quantize import quantize
quantize( quantize(
model_id=model_id, model_id=model_id,
......
from text_generation_server.layers.tensor_parallel import (
TensorParallelColumnLinear,
TensorParallelRowLinear,
TensorParallelEmbedding,
)
from text_generation_server.layers.linear import (
get_linear,
FastLinear,
)
from text_generation_server.layers.speculative import SpeculativeHead
# Just to add the `load` methods.
from text_generation_server.layers.layernorm import load_layer_norm
from text_generation_server.layers.conv import load_conv2d
from text_generation_server.layers.lora import (
LoraLinear,
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
from text_generation_server.utils.import_utils import SYSTEM
import os
from .common import Seqlen
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.")
if SYSTEM == "cuda":
from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif SYSTEM == "rocm":
from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif SYSTEM == "ipex":
from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
else:
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
from dataclasses import dataclass
from text_generation_server.models.globals import FLASH_DECODING
import torch
from typing import Optional
if FLASH_DECODING:
@dataclass
class Seqlen:
input_lengths: torch.Tensor
cu_seqlen_q: Optional[torch.Tensor]
cu_seqlen_k: Optional[torch.Tensor]
def __init__(self, input_lengths):
self.input_lengths = input_lengths
device = self.input_lengths.device
shape = self.input_lengths.shape
cu_seqlen_q = torch.arange(
shape[0] + 1,
device=device,
dtype=torch.int32,
)
cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
# cuda graphs don't like this and this is necessary to clamp within mistral
# Although FA2 might not want the clamping
# cu_seqlen_k[0] = 0
torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:])
self.cu_seqlen_q = cu_seqlen_q
self.cu_seqlen_k = cu_seqlen_k
def clamp(self, max):
# Flash decoding doesn't need to clamp
return self
else:
@dataclass
class Seqlen:
input_lengths: torch.Tensor
def clamp(self, max):
return Seqlen(torch.clamp(self.input_lengths, max=max))
import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
_PARTITION_SIZE = 512
try:
from vllm._C import cache_ops
from vllm._C import ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
)
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
):
if FLASH_DECODING:
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0
)
def paged_attention(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
seqlen: Seqlen,
max_s: int,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
# Copyright 2023 The vLLM team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# value_cache => [num_blocks, num_heads, head_size, block_size]
# block_size = value_cache.shape[3]
block_size = BLOCK_SIZE
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
if FLASH_DECODING:
max_q = 1
max_k = max_s
import flash_attn_2_cuda
# TODO fixme when flash contains the fix.
# Number of splits is not correctly handled
# by the current path
# https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577
# This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied.
out2 = flash_attn_2_cuda.varlen_fwd(
query,
key_cache,
value_cache,
None,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
None,
block_tables,
None,
max_q,
max_k,
0.0, # dropout
softmax_scale,
False, # zero_tensors
True, # causal
-1, # Window_left
-1, # Window right
False, # return softmax
None, # generator
)
return out2[0]
else:
input_lengths = seqlen.input_lengths
from vllm._C import ops
use_v1 = max_s <= 8192 and (
max_num_partitions == 1 or num_seqs * num_heads > 512
)
if use_v1:
ops.paged_attention_v1(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=out.dtype,
device=out.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=out.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
return out
try:
import flash_attn_2_cuda
V2 = True
except ImportError:
try:
import flash_attn_cuda
V2 = False
except ImportError as e:
if major >= 8:
architecture_suffix = f"-{SYSTEM}"
raise ImportError(
"Flash Attention V2 is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
)
elif is_sm75:
raise ImportError(
"Flash Attention is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
"or install flash attention with `cd server && make install install-flash-attention`"
) from e
else:
raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported"
) from e
SUPPORTS_WINDOWING = V2
if V2:
def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
causal=True,
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
return flash_attn_2_cuda.varlen_fwd(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
None,
None,
None,
max_s,
max_s,
0.0,
softmax_scale,
False,
causal,
window_size_left,
0,
False,
None,
)
else:
def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
):
if window_size_left != -1:
raise NotImplementedError(
"window_size_left is only available with flash attn v2"
)
# Flash attention v1 requires q, k and v to have the same number of heads
if k.shape[1] != q.shape[1]:
# MQA expand
if k.shape[1] == 1:
k = k.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = k.shape
k = (
k.unsqueeze(2)
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
if v.shape[1] != q.shape[1]:
# MQA expand
if v.shape[1] == 1:
v = v.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = v.shape
v = (
v.unsqueeze(2)
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
return flash_attn_cuda.fwd(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
softmax_scale,
False,
True,
False,
0,
None,
)
#!/usr/bin/env python
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
(https://tridao.me/publications/flash2/flash2.pdf)
Credits: OpenAI kernel team, AMD ML Frameworks Triton team
Features supported:
1) Fwd with causal masking
2) Any sequence lengths without padding (currently fwd kernel only)
3) Support for different sequence lengths for q and k
4) Nested tensor API currently does not support dropout or bias.
Not currently supported:
1) Non power of two head dims
"""
import torch
import triton
import triton.language as tl
torch_dtype: tl.constexpr = torch.float16
@triton.jit
def cdiv_fn(x, y):
return (x + y - 1) // y
@triton.jit
def max_fn(x, y):
return tl.math.max(x, y)
@triton.jit
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
ms = tl.arange(0, m)
ns = tl.arange(0, n)
return philox_offset + ms[:, None] * stride + ns[None, :]
@triton.jit
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
rng_offsets = dropout_offsets(
philox_seed, philox_offset, dropout_p, m, n, stride
).to(tl.uint32)
# TODO: use tl.randint for better performance
return tl.rand(philox_seed, rng_offsets)
@triton.jit
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride)
rng_keep = rng_output > dropout_p
return rng_keep
@triton.jit
def load_fn(block_ptr, first, second, pad):
if first and second:
tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)
elif first:
tensor = tl.load(block_ptr, boundary_check=(0,), padding_option=pad)
elif second:
tensor = tl.load(block_ptr, boundary_check=(1,), padding_option=pad)
else:
tensor = tl.load(block_ptr)
return tensor
@triton.jit
def _attn_fwd_inner(
acc,
l_i,
m_i,
q,
K_block_ptr,
V_block_ptr,
start_m,
actual_seqlen_k,
dropout_p,
philox_seed,
batch_philox_offset,
encoded_softmax_block_ptr,
block_min,
block_max,
offs_n_causal,
masked_blocks,
n_extra_tokens,
bias_ptr,
IS_CAUSAL: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
OFFS_M: tl.constexpr,
OFFS_N: tl.constexpr,
PRE_LOAD_V: tl.constexpr,
MASK_STEPS: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr,
RETURN_ENCODED_SOFTMAX: tl.constexpr,
PADDED_HEAD: tl.constexpr,
):
# loop over k, v, and update accumulator
for start_n in range(block_min, block_max, BLOCK_N):
# For padded blocks, we will overrun the tensor size if
# we load all BLOCK_N. For others, the blocks are all within range.
k = load_fn(
K_block_ptr,
PADDED_HEAD,
MASK_STEPS and (n_extra_tokens != 0),
"zero",
)
if PRE_LOAD_V:
v = load_fn(
V_block_ptr,
MASK_STEPS and (n_extra_tokens != 0),
PADDED_HEAD,
"zero",
)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
# We start from end of seqlen_k so only the first iteration would need
# to be checked for padding if it is not a multiple of block_n
# TODO: This can be optimized to only be true for the padded block.
if MASK_STEPS: # noqa: SIM102
# If this is the last block / iteration, we want to
# mask if the sequence length is not a multiple of block size
# a solution is to always do BLOCK_M // BLOCK_N + 1 steps
# if not is_modulo_mn. last step might get wasted but that is okay.
# check if this masking works for that case.
if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32)
size_n = start_n + OFFS_N[None, :]
mask = size_n < boundary_m[:, None]
qk = tl.where(mask, qk, float("-inf"))
if IS_CAUSAL:
causal_boundary = start_n + offs_n_causal
causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
qk = tl.where(causal_mask, qk, float("-inf"))
# -- compute qk ----
qk += tl.dot(q, k)
if bias_ptr is not None:
bias = load_fn(
bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero"
)
# While bias is added after multiplying qk with sm_scale, our
# optimization to use 2^x instead of e^x results in an additional
# scale factor of log2(e) which we must also multiply the bias with.
qk += bias * 1.44269504089
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk = qk - m_ij[:, None]
p = tl.math.exp2(qk)
# CAVEAT: Must update l_ij before applying dropout
l_ij = tl.sum(p, 1)
if ENABLE_DROPOUT:
philox_offset = (
batch_philox_offset
+ start_m * BLOCK_M * actual_seqlen_k
+ start_n
- BLOCK_N
)
keep = dropout_mask(
philox_seed,
philox_offset,
dropout_p,
BLOCK_M,
BLOCK_N,
actual_seqlen_k,
)
if RETURN_ENCODED_SOFTMAX:
tl.store(
encoded_softmax_block_ptr,
tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty),
)
p = tl.where(keep, p, 0.0)
elif RETURN_ENCODED_SOFTMAX:
tl.store(
encoded_softmax_block_ptr,
p.to(encoded_softmax_block_ptr.type.element_ty),
)
# -- update output accumulator --
alpha = tl.math.exp2(m_i - m_ij)
acc = acc * alpha[:, None]
if not PRE_LOAD_V:
v = load_fn(
V_block_ptr,
MASK_STEPS and (n_extra_tokens != 0),
PADDED_HEAD,
"zero",
)
# -- update m_i and l_i
l_i = l_i * alpha + l_ij
# update m_i and l_i
m_i = m_ij
acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
if bias_ptr is not None:
bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
if RETURN_ENCODED_SOFTMAX:
encoded_softmax_block_ptr = tl.advance(
encoded_softmax_block_ptr, (0, BLOCK_N)
)
return acc, l_i, m_i
@triton.autotune(
configs=[
triton.Config(
{
"BLOCK_M": 256,
"BLOCK_N": 64,
"waves_per_eu": 2,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=8,
),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 128,
"waves_per_eu": 2,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=4,
),
triton.Config(
{
"BLOCK_M": 256,
"BLOCK_N": 128,
"waves_per_eu": 2,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=8,
),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 64,
"waves_per_eu": 3,
"PRE_LOAD_V": True,
},
num_stages=1,
num_warps=4,
),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 64,
"waves_per_eu": 3,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=4,
),
triton.Config(
{
"BLOCK_M": 64,
"BLOCK_N": 64,
"waves_per_eu": 4,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=8,
),
triton.Config(
{
"BLOCK_M": 32,
"BLOCK_N": 32,
"waves_per_eu": 4,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=8,
),
# TODO: This config fails with head_size not pow2 with data mismatches.
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
# 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
triton.Config(
{
"BLOCK_M": 16,
"BLOCK_N": 16,
"waves_per_eu": 1,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=4,
),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 64,
"waves_per_eu": 1,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=4,
),
],
key=["IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"],
)
@triton.jit
def attn_fwd(
Q,
K,
V,
bias,
sm_scale,
L,
Out,
stride_qz,
stride_qh,
stride_qm,
stride_qk,
stride_kz,
stride_kh,
stride_kn,
stride_kk,
stride_vz,
stride_vh,
stride_vk,
stride_vn,
stride_oz,
stride_oh,
stride_om,
stride_on,
stride_bz,
stride_bh,
stride_bm,
stride_bn,
cu_seqlens_q,
cu_seqlens_k,
dropout_p,
philox_seed,
philox_offset_base,
encoded_softmax,
HQ: tl.constexpr,
HK: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr,
MAX_SEQLENS_Q: tl.constexpr,
MAX_SEQLENS_K: tl.constexpr,
VARLEN: tl.constexpr,
IS_CAUSAL: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
PRE_LOAD_V: tl.constexpr,
BIAS_TYPE: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr,
RETURN_ENCODED_SOFTMAX: tl.constexpr,
):
start_m = tl.program_id(0)
off_h_q = tl.program_id(1)
off_z = tl.program_id(2)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
if VARLEN:
cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
# We have a one-size-fits-all grid in id(0). Some seqlens might be too
# small for all start_m so for those we return early.
if start_m * BLOCK_M > seqlen_q:
return
cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
else:
cu_seqlens_q_start = 0
cu_seqlens_k_start = 0
seqlen_q = MAX_SEQLENS_Q
seqlen_k = MAX_SEQLENS_K
# Now we compute whether we need to exit early due to causal masking.
# This is because for seqlen_q > seqlen_k, M rows of the attn scores
# are completely masked, resulting in 0s written to the output, and
# inf written to LSE. We don't need to do any GEMMs in this case.
# This block of code determines what N is, and if this WG is operating
# on those M rows.
n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
if IS_CAUSAL:
# If seqlen_q == seqlen_k, the attn scores are a square matrix.
# If seqlen_q != seqlen_k, attn scores are rectangular which means
# the causal mask boundary is bottom right aligned, and ends at either
# the top edge (seqlen_q < seqlen_k) or left edge.
# This captures the decrease in n_blocks if we have a rectangular attn
# matrix
n_blocks_seqlen = cdiv_fn(
(start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N
)
# This is what adjusts the block_max for the current WG, only
# if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
n_blocks = min(n_blocks, n_blocks_seqlen)
# If we have no blocks after adjusting for seqlen deltas, this WG is
# part of the blocks that are all 0. We exit early.
if n_blocks <= 0:
o_offset = (
off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
)
O_block_ptr = tl.make_block_ptr(
base=Out + o_offset,
shape=(seqlen_q, BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
# We still need to write 0s to the result
# tl.store(O_block_ptr,
# acc.to(Out.type.element_ty), boundary_check=(0,1))
# l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
# + offs_m
# We store inf to LSE, not -inf because in the bwd pass,
# we subtract this
# from qk which makes it -inf, such that exp(qk - inf) = 0
# for these masked blocks.
# l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
# tl.store(l_ptrs, l)
# TODO: Should dropout and return encoded softmax be handled here?
return
# If MQA / GQA, set the K and V head offsets appropriately.
GROUP_SIZE: tl.constexpr = HQ // HK
if GROUP_SIZE != 1:
off_h_k = off_h_q // GROUP_SIZE
else:
off_h_k = off_h_q
n_extra_tokens = 0
if seqlen_k < BLOCK_N:
n_extra_tokens = BLOCK_N - seqlen_k
elif seqlen_k % BLOCK_N:
n_extra_tokens = seqlen_k % BLOCK_N
PADDED_HEAD: tl.constexpr = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL
# Compute pointers for all the tensors used in this kernel.
q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm
Q_block_ptr = tl.make_block_ptr(
base=Q + q_offset,
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn
K_block_ptr = tl.make_block_ptr(
base=K + k_offset,
shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1),
)
v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk
V_block_ptr = tl.make_block_ptr(
base=V + v_offset,
shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),
strides=(stride_vk, stride_vn),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0),
)
if BIAS_TYPE != 0:
bias_ptr = tl.make_block_ptr(
base=bias + off_h_q * stride_bh,
shape=(seqlen_q, seqlen_k),
strides=(stride_bm, stride_bn),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_N),
order=(1, 0),
)
else:
bias_ptr = None
if ENABLE_DROPOUT:
batch_philox_offset = (
philox_offset_base + (off_z * HQ + off_h_q) * seqlen_q * seqlen_k
)
else:
batch_philox_offset = 0
# We can ask to return the dropout mask without actually doing any dropout.
# In this case, we return an invalid pointer so indicate the mask is not i
# valid.
# TODO: Fix encoded softmax. It currently uses just h_q in the base offset.
if RETURN_ENCODED_SOFTMAX:
encoded_softmax_block_ptr = tl.make_block_ptr(
base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,
shape=(seqlen_q, seqlen_k),
strides=(seqlen_k, 1),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_N),
order=(1, 0),
)
else:
encoded_softmax_block_ptr = 0
# initialize pointer to m and l
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# scale sm_scale by log_2(e) and use 2^x in the loop as we do not
# have native e^x support in HW.
qk_scale = sm_scale * 1.44269504089
# Q is loaded once at the beginning and shared by all N blocks.
q = load_fn(Q_block_ptr, True, PADDED_HEAD, "zero")
q = (q * qk_scale).to(Q_block_ptr.type.element_ty)
# Here we compute how many full and masked blocks we have.
padded_block_k = n_extra_tokens != 0
is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
if IS_CAUSAL:
# There are always at least BLOCK_M // BLOCK_N masked blocks.
# Additionally there might be one more due to dissimilar seqlens.
masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
else:
# Padding on Q does not need to be masked in the FA loop.
masked_blocks = padded_block_k
# if IS_CAUSAL, not is_modulo_mn does not always result in an additional
# block. In this case we might exceed n_blocks so pick the min.
masked_blocks = min(masked_blocks, n_blocks)
n_full_blocks = n_blocks - masked_blocks
block_min = 0
block_max = n_blocks * BLOCK_N
# Compute for full blocks. Here we set causal to false regardless of its
# value because there is no masking. Similarly we do not need padding.
if n_full_blocks > 0:
block_max = (n_blocks - masked_blocks) * BLOCK_N
acc, l_i, m_i = _attn_fwd_inner(
acc,
l_i,
m_i,
q,
K_block_ptr,
V_block_ptr,
start_m,
seqlen_k,
dropout_p,
philox_seed,
batch_philox_offset,
encoded_softmax_block_ptr,
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
block_min,
block_max,
0,
0,
0,
bias_ptr,
# IS_CAUSAL, ....
False,
BLOCK_M,
BLOCK_DMODEL,
BLOCK_N,
offs_m,
offs_n,
# _, MASK_STEPS, ...
PRE_LOAD_V,
False,
ENABLE_DROPOUT,
RETURN_ENCODED_SOFTMAX,
PADDED_HEAD,
)
block_min = block_max
block_max = n_blocks * BLOCK_N
tl.debug_barrier()
# Remaining blocks, if any, are full / not masked.
if masked_blocks > 0:
offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0
K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0))
if bias_ptr is not None:
bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))
if RETURN_ENCODED_SOFTMAX:
encoded_softmax_block_ptr = tl.advance(
encoded_softmax_block_ptr, (0, n_full_blocks)
)
acc, l_i, m_i = _attn_fwd_inner(
acc,
l_i,
m_i,
q,
K_block_ptr,
V_block_ptr,
start_m,
seqlen_k,
dropout_p,
philox_seed,
batch_philox_offset,
encoded_softmax_block_ptr,
block_min,
block_max,
offs_n_causal,
masked_blocks,
n_extra_tokens,
bias_ptr,
IS_CAUSAL,
BLOCK_M,
BLOCK_DMODEL,
BLOCK_N,
offs_m,
offs_n,
# _, MASK_STEPS, ...
PRE_LOAD_V,
True,
ENABLE_DROPOUT,
RETURN_ENCODED_SOFTMAX,
PADDED_HEAD,
)
# epilogue
acc = acc / l_i[:, None]
if ENABLE_DROPOUT:
acc = acc / (1 - dropout_p)
# If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
# then we have one block with a row of all NaNs which come from computing
# softmax over a row of all -infs (-inf - inf = NaN). We check for that here
# and store 0s where there are NaNs as these rows should've been zeroed out.
end_m_idx = (start_m + 1) * BLOCK_M
start_m_idx = start_m * BLOCK_M
causal_start_idx = seqlen_q - seqlen_k
acc = acc.to(Out.type.element_ty)
if IS_CAUSAL: # noqa: SIM102
if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
out_mask_boundary = tl.full(
(BLOCK_DMODEL,), causal_start_idx, dtype=tl.int32
)
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :]
z = 0.0
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
# write back LSE
# l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last
# few rows. This is only true for the last M block. For others,
# overflow_size will be -ve
# overflow_size = end_m_idx - seqlen_q
# if overflow_size > 0:
# boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
# # This is a > check because mask being 0 blocks the store.
# l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
# tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
# else:
# tl.store(l_ptrs, m_i + tl.math.log2(l_i))
# write back O
o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
O_block_ptr = tl.make_block_ptr(
base=Out + o_offset,
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
# Need boundary check on this to make sure the padding from the
# Q and KV tensors in both dims are not part of what we store back.
# TODO: Do the boundary check optionally.
tl.store(O_block_ptr, acc, boundary_check=(0, 1))
def check_args(
q,
k,
v,
o,
varlen=True,
max_seqlens=None,
cu_seqlens_q=None,
cu_seqlens_k=None,
):
assert q.dim() == k.dim() and q.dim() == v.dim()
if varlen:
assert q.dim() == 3
total_q, nheads_q, head_size = q.shape
total_k, nheads_k, _ = k.shape
assert cu_seqlens_q is not None
assert cu_seqlens_k is not None
assert len(cu_seqlens_q) == len(cu_seqlens_k)
else:
assert q.dim() == 4
batch, nheads_q, seqlen_q, head_size = q.shape
_, nheads_k, seqlen_k, _ = k.shape
assert max_seqlens > 0
assert k.shape == v.shape
assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
# TODO: Change assert if we support qkl f8 and v f16
assert q.dtype == k.dtype and q.dtype == v.dtype
# TODO: Fix assert to check head size <=256 once supported
assert head_size <= 128
assert o.shape == q.shape
assert (nheads_q % nheads_k) == 0
class _attention(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q,
k,
v,
o,
cu_seqlens_q,
cu_seqlens_k,
max_seqlens_q,
max_seqlens_k,
causal=False,
sm_scale=1.0,
bias=None,
):
if o is None:
o = torch.empty_like(q, dtype=v.dtype)
check_args(
q,
k,
v,
o,
varlen=True,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
)
if True: # varlen
total_q, nheads_q, head_size = q.shape
total_k, nheads_k, _ = k.shape
batch = len(cu_seqlens_q) - 1
q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
else:
batch, seqlen_q, nheads_q, head_size = q.shape
_, seqlen_k, nheads_k, _ = k.shape
q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
# Get closest power of 2 over or equal to 32.
padded_d_model = 1 << (head_size - 1).bit_length()
padded_d_model = max(padded_d_model, 16)
grid = lambda META: (
triton.cdiv(max_seqlens_q, META["BLOCK_M"]),
nheads_q,
batch,
)
encoded_softmax = None
# Seed the RNG so we get reproducible results for testing.
philox_seed = 0x1BF52
philox_offset = 0x1D4B42
if bias is not None:
bias_strides = (
bias.stride(0),
bias.stride(1),
bias.stride(2),
bias.stride(3),
)
else:
bias_strides = (0, 0, 0, 0)
attn_fwd[grid](
q,
k,
v,
bias,
sm_scale,
None,
o,
*q_strides,
*k_strides,
*v_strides,
*o_strides,
*bias_strides,
cu_seqlens_q,
cu_seqlens_k,
dropout_p=0.0,
philox_seed=philox_seed,
philox_offset_base=philox_offset,
encoded_softmax=encoded_softmax,
HQ=nheads_q,
HK=nheads_k,
ACTUAL_BLOCK_DMODEL=head_size,
MAX_SEQLENS_Q=max_seqlens_q,
MAX_SEQLENS_K=max_seqlens_k,
IS_CAUSAL=causal,
VARLEN=True,
BLOCK_DMODEL=padded_d_model,
BIAS_TYPE=0 if bias is None else 1,
ENABLE_DROPOUT=False,
RETURN_ENCODED_SOFTMAX=False,
)
ctx.grid = grid
ctx.sm_scale = sm_scale
ctx.BLOCK_DMODEL = head_size
ctx.causal = causal
ctx.dropout_p = 0.0
ctx.philox_seed = philox_seed
ctx.philox_offset = philox_offset
ctx.encoded_softmax = encoded_softmax
ctx.return_encoded_softmax = False
return o, encoded_softmax
triton_attention = _attention.apply
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment