Unverified commit abd58ff8, authored by Nicolas Patry, committed by GitHub

feat(server): Rework model loading (#344)

# What does this PR do?

Reworked the loading logic. The idea is to use cleaner loading code:

- Remove need for `no_init_weights`
- Remove all the weird `bnb_linear`, `load_weights`, and `post_load_weights` helpers.

New code layout:

- New class `Weights` in charge of loading the weights from multiple files into the appropriate tensors (potentially sharded)
- TP layers are now "shells": they contain the code that decides which kind of sharding is needed, plus the eventual `all_reduce`. They do not inherit from Linear; instead they contain some kind of Linear
- the contained linear can be either `FastLinear` or `BnbLinear`, with a GPTQ Linear coming next
- All modeling code is explicitly written for sharding; the process group is just a no-op for non-sharded code (which removes a lot of test cases). A rough sketch of the new layout follows below.
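
To make the layout concrete, here is a minimal, hypothetical sketch of what a `Weights` loader and a tensor-parallel "shell" layer can look like. This is not the code added by this PR: the method names (`get_tensor`, `get_sharded`), the `TensorParallelRowLinear` name, and the safetensors-backed routing are illustrative assumptions.

```python
from typing import Dict, List

import torch
import torch.distributed as dist
from safetensors import safe_open


class Weights:
    """Maps tensor names to safetensors files and hands out (optionally sharded) tensors."""

    def __init__(self, filenames: List[str], device: str, dtype: torch.dtype, process_group):
        self.device = device
        self.dtype = dtype
        self.process_group = process_group
        # Index every tensor name once so each lookup knows which file to open.
        self.routing: Dict[str, str] = {}
        for filename in filenames:
            with safe_open(filename, framework="pt") as f:
                for name in f.keys():
                    self.routing[name] = filename

    def get_tensor(self, name: str) -> torch.Tensor:
        with safe_open(self.routing[name], framework="pt") as f:
            return f.get_tensor(name).to(self.dtype).to(self.device)

    def get_sharded(self, name: str, dim: int) -> torch.Tensor:
        # Each rank only materializes its slice along `dim` (assumes even divisibility).
        rank, world_size = self.process_group.rank(), self.process_group.size()
        tensor = self.get_tensor(name)
        block = tensor.shape[dim] // world_size
        return tensor.narrow(dim, rank * block, block).contiguous()


class TensorParallelRowLinear(torch.nn.Module):
    """Shell layer: it owns "some kind of Linear" plus the eventual all_reduce."""

    def __init__(self, linear: torch.nn.Module, process_group):
        super().__init__()
        self.linear = linear  # FastLinear, a bitsandbytes linear, later a GPTQ linear...
        self.process_group = process_group

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.linear(x)
        # In the new layout the process group is effectively a no-op for non-sharded runs.
        if self.process_group is not None and self.process_group.size() > 1:
            dist.all_reduce(out, group=self.process_group)
        return out
```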

![Screenshot from 2023-05-19 23-19-59](https://github.com/huggingface/text-generation-inference/assets/204321/9a802654-74a3-488c-87a8-073743a6143f)

---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.taildb5d.ts.net>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.ec2.internal>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
Co-authored-by: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
parent 19c41824
.idea
target
router/tokenizer.json
+*__pycache__*
@@ -2,6 +2,8 @@
FROM lukemathwalker/cargo-chef:latest-rust-1.69 AS chef
WORKDIR /usr/src
+ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
FROM chef as planner
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
@@ -98,14 +100,14 @@ COPY server/Makefile-flash-att Makefile
RUN make build-flash-attention
# Build Transformers CUDA kernels
-FROM kernel-builder as transformers-builder
+FROM kernel-builder as custom-kernels-builder
WORKDIR /usr/src
-COPY server/Makefile-transformers Makefile
+COPY server/custom_kernels/ .
# Build specific version of transformers
-RUN BUILD_EXTENSIONS="True" make build-transformers
+RUN python setup.py build
# Text Generation Inference base image
FROM nvidia/cuda:11.8.0-base-ubuntu20.04 as base
@@ -136,11 +138,10 @@ COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# Copy build artifacts from transformers builder
-COPY --from=transformers-builder /usr/src/transformers /usr/src/transformers
-COPY --from=transformers-builder /usr/src/transformers/build/lib.linux-x86_64-cpython-39/transformers /usr/src/transformers/src/transformers
+COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels
-# Install transformers dependencies
-RUN cd /usr/src/transformers && pip install -e . --no-cache-dir && pip install einops --no-cache-dir
+# Install flash-attention dependencies
+RUN pip install einops --no-cache-dir
# Install server
COPY proto proto
...
install-server:
	cd server && make install
+install-custom-kernels:
+	if [ "$$BUILD_EXTENSIONS" == "True" ]; then cd server/custom_kernels && python setup.py install; else echo "Custom kernels are disabled, you need to set the BUILD_EXTENSIONS environment variable to 'True' in order to build them. (Please read the docs; kernels might not work on all hardware)"; fi
install-integration-tests:
	cd integration-tests && pip install -r requirements.txt
	cd clients/python && pip install .
@@ -14,7 +17,7 @@ install-launcher:
install-benchmark:
	cd benchmark && cargo install --path .
-install: install-server install-router install-launcher
+install: install-server install-router install-launcher install-custom-kernels
server-dev:
	cd server && make run-dev
...
@@ -209,6 +209,7 @@ def launcher(event_loop):
        num_shard: Optional[int] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
+        use_flash_attention: bool = True,
    ):
        port = random.randint(8000, 10_000)
        master_port = random.randint(10_000, 20_000)
@@ -240,6 +241,9 @@ def launcher(event_loop):
        env = os.environ
        env["LOG_LEVEL"] = "info,text_generation_router=debug"
+        if not use_flash_attention:
+            env["USE_FLASH_ATTENTION"] = "false"
        with subprocess.Popen(
            args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
        ) as process:
@@ -254,12 +258,16 @@ def launcher(event_loop):
            process.stdout.close()
            process.stderr.close()
+        if not use_flash_attention:
+            del env["USE_FLASH_ATTENTION"]
    @contextlib.contextmanager
    def docker_launcher(
        model_id: str,
        num_shard: Optional[int] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
+        use_flash_attention: bool = True,
    ):
        port = random.randint(8000, 10_000)
@@ -287,6 +295,9 @@ def launcher(event_loop):
        gpu_count = num_shard if num_shard is not None else 1
        env = {"LOG_LEVEL": "info,text_generation_router=debug"}
+        if not use_flash_attention:
+            env["USE_FLASH_ATTENTION"] = "false"
        if HUGGING_FACE_HUB_TOKEN is not None:
            env["HUGGING_FACE_HUB_TOKEN"] = HUGGING_FACE_HUB_TOKEN
@@ -310,6 +321,9 @@ def launcher(event_loop):
        yield ContainerLauncherHandle(client, container.name, port)
+        if not use_flash_attention:
+            del env["USE_FLASH_ATTENTION"]
        try:
            container.stop()
            container.wait()
...
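
For context, the `USE_FLASH_ATTENTION` variable these fixtures export is presumably what disables the flash-attention code path in the launched server. A hedged sketch of such gating follows; the import name and fallback message are assumptions, not the project's actual code.

```python
import os

# Illustrative only: how a server process might honor the USE_FLASH_ATTENTION
# variable that the fixtures above set to "false" before launching it.
HAS_FLASH_ATTN = False
if os.getenv("USE_FLASH_ATTENTION", "true").lower() == "true":
    try:
        import flash_attn_cuda  # assumed module name from a flash-attention v1 build

        HAS_FLASH_ATTN = True
    except ImportError:
        print("Flash Attention kernels are not available, falling back to plain attention")
```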
@@ -11,17 +11,17 @@
},
{
"id": 1459,
-"logprob": -5.6289062,
+"logprob": -5.6328125,
"text": " print"
},
{
"id": 81,
-"logprob": -1.6005859,
+"logprob": -1.6035156,
"text": "_"
},
{
"id": 7656,
-"logprob": -5.9921875,
+"logprob": -5.9882812,
"text": "hello"
}
],
@@ -59,19 +59,19 @@
},
{
"id": 10896,
-"logprob": -0.3659668,
+"logprob": -0.38549805,
"special": false,
"text": " World"
},
{
"id": 657,
-"logprob": -0.49804688,
+"logprob": -0.5229492,
"special": false,
"text": "\")"
},
{
"id": 203,
-"logprob": -0.11279297,
+"logprob": -0.10632324,
"special": false,
"text": "\n"
},
@@ -113,7 +113,7 @@
},
{
"id": 426,
-"logprob": -0.051635742,
+"logprob": 0.0,
"special": false,
"text": "name"
},
...
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|USER|>"
},
{
"id": 1276,
"logprob": -4.5546875,
"text": "What"
},
{
"id": 434,
"logprob": -4.1992188,
"text": "'s"
},
{
"id": 634,
"logprob": -5.125,
"text": " your"
},
{
"id": 12315,
"logprob": -9.8984375,
"text": " mood"
},
{
"id": 3063,
"logprob": -4.0976562,
"text": " today"
},
{
"id": 32,
"logprob": -0.14562988,
"text": "?"
},
{
"id": 50279,
"logprob": -0.26733398,
"text": "<|ASSISTANT|>"
}
],
"seed": null,
"tokens": [
{
"id": 42,
"logprob": -0.86279297,
"special": false,
"text": "I"
},
{
"id": 1353,
"logprob": -0.94921875,
"special": false,
"text": "'m"
},
{
"id": 7016,
"logprob": -2.1835938,
"special": false,
"text": " sorry"
},
{
"id": 13,
"logprob": -0.074035645,
"special": false,
"text": ","
},
{
"id": 1394,
"logprob": -0.86376953,
"special": false,
"text": "You"
},
{
"id": 452,
"logprob": -1.2070312,
"special": false,
"text": " have"
},
{
"id": 247,
"logprob": -1.4365234,
"special": false,
"text": " a"
},
{
"id": 4327,
"logprob": -1.109375,
"special": false,
"text": " choice"
},
{
"id": 273,
"logprob": -0.93408203,
"special": false,
"text": " of"
},
{
"id": 752,
"logprob": -1.8808594,
"special": false,
"text": " what"
}
]
},
"generated_text": "I'm sorry,You have a choice of what"
}
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|USER|>"
},
{
"id": 1276,
"logprob": -4.5546875,
"text": "What"
},
{
"id": 434,
"logprob": -4.1953125,
"text": "'s"
},
{
"id": 634,
"logprob": -5.125,
"text": " your"
},
{
"id": 12315,
"logprob": -9.8828125,
"text": " mood"
},
{
"id": 3063,
"logprob": -3.9980469,
"text": " today"
},
{
"id": 32,
"logprob": -0.14672852,
"text": "?"
},
{
"id": 50279,
"logprob": -0.26489258,
"text": "<|ASSISTANT|>"
}
],
"seed": null,
"tokens": [
{
"id": 42,
"logprob": -0.8618164,
"special": false,
"text": "I"
},
{
"id": 1353,
"logprob": -0.9506836,
"special": false,
"text": "'m"
},
{
"id": 7016,
"logprob": -2.1738281,
"special": false,
"text": " sorry"
},
{
"id": 13,
"logprob": -0.0758667,
"special": false,
"text": ","
},
{
"id": 1394,
"logprob": -0.9135742,
"special": false,
"text": "You"
},
{
"id": 452,
"logprob": -1.1445312,
"special": false,
"text": " have"
},
{
"id": 247,
"logprob": -1.4375,
"special": false,
"text": " a"
},
{
"id": 4327,
"logprob": -1.1103516,
"special": false,
"text": " choice"
},
{
"id": 273,
"logprob": -1.0058594,
"special": false,
"text": " of"
},
{
"id": 752,
"logprob": -1.921875,
"special": false,
"text": " what"
}
]
},
"generated_text": "I'm sorry,You have a choice of what"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|USER|>"
},
{
"id": 1276,
"logprob": -4.5546875,
"text": "What"
},
{
"id": 434,
"logprob": -4.1953125,
"text": "'s"
},
{
"id": 634,
"logprob": -5.125,
"text": " your"
},
{
"id": 12315,
"logprob": -9.8828125,
"text": " mood"
},
{
"id": 3063,
"logprob": -3.9980469,
"text": " today"
},
{
"id": 32,
"logprob": -0.14672852,
"text": "?"
},
{
"id": 50279,
"logprob": -0.26489258,
"text": "<|ASSISTANT|>"
}
],
"seed": null,
"tokens": [
{
"id": 42,
"logprob": -0.8618164,
"special": false,
"text": "I"
},
{
"id": 1353,
"logprob": -0.9506836,
"special": false,
"text": "'m"
},
{
"id": 7016,
"logprob": -2.1738281,
"special": false,
"text": " sorry"
},
{
"id": 13,
"logprob": -0.0758667,
"special": false,
"text": ","
},
{
"id": 1394,
"logprob": -0.9135742,
"special": false,
"text": "You"
},
{
"id": 452,
"logprob": -1.1445312,
"special": false,
"text": " have"
},
{
"id": 247,
"logprob": -1.4375,
"special": false,
"text": " a"
},
{
"id": 4327,
"logprob": -1.1103516,
"special": false,
"text": " choice"
},
{
"id": 273,
"logprob": -1.0058594,
"special": false,
"text": " of"
},
{
"id": 752,
"logprob": -1.921875,
"special": false,
"text": " what"
}
]
},
"generated_text": "I'm sorry,You have a choice of what"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|USER|>"
},
{
"id": 1276,
"logprob": -4.5546875,
"text": "What"
},
{
"id": 434,
"logprob": -4.1953125,
"text": "'s"
},
{
"id": 634,
"logprob": -5.125,
"text": " your"
},
{
"id": 12315,
"logprob": -9.8828125,
"text": " mood"
},
{
"id": 3063,
"logprob": -3.9980469,
"text": " today"
},
{
"id": 32,
"logprob": -0.14672852,
"text": "?"
},
{
"id": 50279,
"logprob": -0.26489258,
"text": "<|ASSISTANT|>"
}
],
"seed": null,
"tokens": [
{
"id": 42,
"logprob": -0.8618164,
"special": false,
"text": "I"
},
{
"id": 1353,
"logprob": -0.9506836,
"special": false,
"text": "'m"
},
{
"id": 7016,
"logprob": -2.1738281,
"special": false,
"text": " sorry"
},
{
"id": 13,
"logprob": -0.0758667,
"special": false,
"text": ","
},
{
"id": 1394,
"logprob": -0.9135742,
"special": false,
"text": "You"
},
{
"id": 452,
"logprob": -1.1445312,
"special": false,
"text": " have"
},
{
"id": 247,
"logprob": -1.4375,
"special": false,
"text": " a"
},
{
"id": 4327,
"logprob": -1.1103516,
"special": false,
"text": " choice"
},
{
"id": 273,
"logprob": -1.0058594,
"special": false,
"text": " of"
},
{
"id": 752,
"logprob": -1.921875,
"special": false,
"text": " what"
}
]
},
"generated_text": "I'm sorry,You have a choice of what"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|USER|>"
},
{
"id": 1276,
"logprob": -4.5546875,
"text": "What"
},
{
"id": 434,
"logprob": -4.1953125,
"text": "'s"
},
{
"id": 634,
"logprob": -5.125,
"text": " your"
},
{
"id": 12315,
"logprob": -9.8828125,
"text": " mood"
},
{
"id": 3063,
"logprob": -3.9980469,
"text": " today"
},
{
"id": 32,
"logprob": -0.14672852,
"text": "?"
},
{
"id": 50279,
"logprob": -0.26489258,
"text": "<|ASSISTANT|>"
}
],
"seed": null,
"tokens": [
{
"id": 42,
"logprob": -0.8618164,
"special": false,
"text": "I"
},
{
"id": 1353,
"logprob": -0.9506836,
"special": false,
"text": "'m"
},
{
"id": 7016,
"logprob": -2.1738281,
"special": false,
"text": " sorry"
},
{
"id": 13,
"logprob": -0.0758667,
"special": false,
"text": ","
},
{
"id": 1394,
"logprob": -0.9135742,
"special": false,
"text": "You"
},
{
"id": 452,
"logprob": -1.1445312,
"special": false,
"text": " have"
},
{
"id": 247,
"logprob": -1.4375,
"special": false,
"text": " a"
},
{
"id": 4327,
"logprob": -1.1103516,
"special": false,
"text": " choice"
},
{
"id": 273,
"logprob": -1.0058594,
"special": false,
"text": " of"
},
{
"id": 752,
"logprob": -1.921875,
"special": false,
"text": " what"
}
]
},
"generated_text": "I'm sorry,You have a choice of what"
}
]
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|prompter|>"
},
{
"id": 1276,
"logprob": -8.0234375,
"text": "What"
},
{
"id": 310,
"logprob": -5.4179688,
"text": " is"
},
{
"id": 247,
"logprob": -2.1542969,
"text": " a"
},
{
"id": 1167,
"logprob": -5.359375,
"text": " mem"
},
{
"id": 70,
"logprob": -0.006038666,
"text": "e"
},
{
"id": 13,
"logprob": -7.328125,
"text": ","
},
{
"id": 285,
"logprob": -0.3173828,
"text": " and"
},
{
"id": 752,
"logprob": -2.0625,
"text": " what"
},
{
"id": 434,
"logprob": -5.7734375,
"text": "'s"
},
{
"id": 253,
"logprob": -0.74072266,
"text": " the"
},
{
"id": 2892,
"logprob": -6.5898438,
"text": " history"
},
{
"id": 3212,
"logprob": -2.2949219,
"text": " behind"
},
{
"id": 436,
"logprob": -11.40625,
"text": " this"
},
{
"id": 3159,
"logprob": -2.1113281,
"text": " word"
},
{
"id": 32,
"logprob": -0.008056641,
"text": "?"
},
{
"id": 0,
"logprob": -2.3300781,
"text": "<|endoftext|>"
},
{
"id": 50281,
"logprob": -18.28125,
"text": "<|assistant|>"
}
],
"seed": null,
"tokens": [
{
"id": 510,
"logprob": -0.5878906,
"special": false,
"text": "The"
},
{
"id": 3159,
"logprob": -0.5449219,
"special": false,
"text": " word"
},
{
"id": 346,
"logprob": -0.05038452,
"special": false,
"text": " \""
},
{
"id": 6441,
"logprob": -0.002292633,
"special": false,
"text": "mem"
},
{
"id": 70,
"logprob": -1.3828278e-05,
"special": false,
"text": "e"
},
{
"id": 3,
"logprob": -0.0010242462,
"special": false,
"text": "\""
},
{
"id": 369,
"logprob": -0.090270996,
"special": false,
"text": " was"
},
{
"id": 806,
"logprob": -0.12719727,
"special": false,
"text": " first"
},
{
"id": 908,
"logprob": -0.016571045,
"special": false,
"text": " used"
},
{
"id": 275,
"logprob": -0.43432617,
"special": false,
"text": " in"
}
]
},
"generated_text": "The word \"meme\" was first used in"
}
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|prompter|>"
},
{
"id": 1276,
"logprob": -8.0234375,
"text": "What"
},
{
"id": 310,
"logprob": -5.4179688,
"text": " is"
},
{
"id": 247,
"logprob": -2.1542969,
"text": " a"
},
{
"id": 1167,
"logprob": -5.359375,
"text": " mem"
},
{
"id": 70,
"logprob": -0.006038666,
"text": "e"
},
{
"id": 13,
"logprob": -7.328125,
"text": ","
},
{
"id": 285,
"logprob": -0.3173828,
"text": " and"
},
{
"id": 752,
"logprob": -2.0625,
"text": " what"
},
{
"id": 434,
"logprob": -5.7734375,
"text": "'s"
},
{
"id": 253,
"logprob": -0.74072266,
"text": " the"
},
{
"id": 2892,
"logprob": -6.5898438,
"text": " history"
},
{
"id": 3212,
"logprob": -2.2949219,
"text": " behind"
},
{
"id": 436,
"logprob": -11.40625,
"text": " this"
},
{
"id": 3159,
"logprob": -2.1113281,
"text": " word"
},
{
"id": 32,
"logprob": -0.008056641,
"text": "?"
},
{
"id": 0,
"logprob": -2.3300781,
"text": "<|endoftext|>"
},
{
"id": 50281,
"logprob": -18.28125,
"text": "<|assistant|>"
}
],
"seed": null,
"tokens": [
{
"id": 510,
"logprob": -0.5878906,
"special": false,
"text": "The"
},
{
"id": 3159,
"logprob": -0.5498047,
"special": false,
"text": " word"
},
{
"id": 346,
"logprob": -0.04815674,
"special": false,
"text": " \""
},
{
"id": 6441,
"logprob": -0.002313614,
"special": false,
"text": "mem"
},
{
"id": 70,
"logprob": -1.2636185e-05,
"special": false,
"text": "e"
},
{
"id": 3,
"logprob": -0.0010147095,
"special": false,
"text": "\""
},
{
"id": 369,
"logprob": -0.0859375,
"special": false,
"text": " was"
},
{
"id": 806,
"logprob": -0.12609863,
"special": false,
"text": " first"
},
{
"id": 908,
"logprob": -0.016601562,
"special": false,
"text": " used"
},
{
"id": 275,
"logprob": -0.38256836,
"special": false,
"text": " in"
}
]
},
"generated_text": "The word \"meme\" was first used in"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|prompter|>"
},
{
"id": 1276,
"logprob": -8.0234375,
"text": "What"
},
{
"id": 310,
"logprob": -5.421875,
"text": " is"
},
{
"id": 247,
"logprob": -2.1640625,
"text": " a"
},
{
"id": 1167,
"logprob": -5.40625,
"text": " mem"
},
{
"id": 70,
"logprob": -0.005420685,
"text": "e"
},
{
"id": 13,
"logprob": -7.2226562,
"text": ","
},
{
"id": 285,
"logprob": -0.26879883,
"text": " and"
},
{
"id": 752,
"logprob": -2.1992188,
"text": " what"
},
{
"id": 434,
"logprob": -5.46875,
"text": "'s"
},
{
"id": 253,
"logprob": -0.8017578,
"text": " the"
},
{
"id": 2892,
"logprob": -6.6796875,
"text": " history"
},
{
"id": 3212,
"logprob": -2.1972656,
"text": " behind"
},
{
"id": 436,
"logprob": -11.4453125,
"text": " this"
},
{
"id": 3159,
"logprob": -2.1933594,
"text": " word"
},
{
"id": 32,
"logprob": -0.007858276,
"text": "?"
},
{
"id": 0,
"logprob": -2.328125,
"text": "<|endoftext|>"
},
{
"id": 50281,
"logprob": -18.21875,
"text": "<|assistant|>"
}
],
"seed": null,
"tokens": [
{
"id": 510,
"logprob": -0.6201172,
"special": false,
"text": "The"
},
{
"id": 3159,
"logprob": -0.546875,
"special": false,
"text": " word"
},
{
"id": 346,
"logprob": -0.051879883,
"special": false,
"text": " \""
},
{
"id": 6441,
"logprob": -0.0020179749,
"special": false,
"text": "mem"
},
{
"id": 70,
"logprob": -9.059906e-06,
"special": false,
"text": "e"
},
{
"id": 3,
"logprob": -0.00096797943,
"special": false,
"text": "\""
},
{
"id": 369,
"logprob": -0.07940674,
"special": false,
"text": " was"
},
{
"id": 806,
"logprob": -0.12182617,
"special": false,
"text": " first"
},
{
"id": 908,
"logprob": -0.017227173,
"special": false,
"text": " used"
},
{
"id": 275,
"logprob": -0.44482422,
"special": false,
"text": " in"
}
]
},
"generated_text": "The word \"meme\" was first used in"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|prompter|>"
},
{
"id": 1276,
"logprob": -8.0234375,
"text": "What"
},
{
"id": 310,
"logprob": -5.421875,
"text": " is"
},
{
"id": 247,
"logprob": -2.1640625,
"text": " a"
},
{
"id": 1167,
"logprob": -5.40625,
"text": " mem"
},
{
"id": 70,
"logprob": -0.005420685,
"text": "e"
},
{
"id": 13,
"logprob": -7.2226562,
"text": ","
},
{
"id": 285,
"logprob": -0.26879883,
"text": " and"
},
{
"id": 752,
"logprob": -2.1992188,
"text": " what"
},
{
"id": 434,
"logprob": -5.46875,
"text": "'s"
},
{
"id": 253,
"logprob": -0.8017578,
"text": " the"
},
{
"id": 2892,
"logprob": -6.6796875,
"text": " history"
},
{
"id": 3212,
"logprob": -2.1972656,
"text": " behind"
},
{
"id": 436,
"logprob": -11.4453125,
"text": " this"
},
{
"id": 3159,
"logprob": -2.1933594,
"text": " word"
},
{
"id": 32,
"logprob": -0.007858276,
"text": "?"
},
{
"id": 0,
"logprob": -2.328125,
"text": "<|endoftext|>"
},
{
"id": 50281,
"logprob": -18.21875,
"text": "<|assistant|>"
}
],
"seed": null,
"tokens": [
{
"id": 510,
"logprob": -0.6201172,
"special": false,
"text": "The"
},
{
"id": 3159,
"logprob": -0.546875,
"special": false,
"text": " word"
},
{
"id": 346,
"logprob": -0.051879883,
"special": false,
"text": " \""
},
{
"id": 6441,
"logprob": -0.0020179749,
"special": false,
"text": "mem"
},
{
"id": 70,
"logprob": -9.059906e-06,
"special": false,
"text": "e"
},
{
"id": 3,
"logprob": -0.00096797943,
"special": false,
"text": "\""
},
{
"id": 369,
"logprob": -0.07940674,
"special": false,
"text": " was"
},
{
"id": 806,
"logprob": -0.12182617,
"special": false,
"text": " first"
},
{
"id": 908,
"logprob": -0.017227173,
"special": false,
"text": " used"
},
{
"id": 275,
"logprob": -0.44482422,
"special": false,
"text": " in"
}
]
},
"generated_text": "The word \"meme\" was first used in"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|prompter|>"
},
{
"id": 1276,
"logprob": -8.0234375,
"text": "What"
},
{
"id": 310,
"logprob": -5.421875,
"text": " is"
},
{
"id": 247,
"logprob": -2.1640625,
"text": " a"
},
{
"id": 1167,
"logprob": -5.40625,
"text": " mem"
},
{
"id": 70,
"logprob": -0.005420685,
"text": "e"
},
{
"id": 13,
"logprob": -7.2226562,
"text": ","
},
{
"id": 285,
"logprob": -0.26879883,
"text": " and"
},
{
"id": 752,
"logprob": -2.1992188,
"text": " what"
},
{
"id": 434,
"logprob": -5.46875,
"text": "'s"
},
{
"id": 253,
"logprob": -0.8017578,
"text": " the"
},
{
"id": 2892,
"logprob": -6.6796875,
"text": " history"
},
{
"id": 3212,
"logprob": -2.1972656,
"text": " behind"
},
{
"id": 436,
"logprob": -11.4453125,
"text": " this"
},
{
"id": 3159,
"logprob": -2.1933594,
"text": " word"
},
{
"id": 32,
"logprob": -0.007858276,
"text": "?"
},
{
"id": 0,
"logprob": -2.328125,
"text": "<|endoftext|>"
},
{
"id": 50281,
"logprob": -18.21875,
"text": "<|assistant|>"
}
],
"seed": null,
"tokens": [
{
"id": 510,
"logprob": -0.6201172,
"special": false,
"text": "The"
},
{
"id": 3159,
"logprob": -0.546875,
"special": false,
"text": " word"
},
{
"id": 346,
"logprob": -0.051879883,
"special": false,
"text": " \""
},
{
"id": 6441,
"logprob": -0.0020179749,
"special": false,
"text": "mem"
},
{
"id": 70,
"logprob": -1.04904175e-05,
"special": false,
"text": "e"
},
{
"id": 3,
"logprob": -0.0009560585,
"special": false,
"text": "\""
},
{
"id": 369,
"logprob": -0.08557129,
"special": false,
"text": " was"
},
{
"id": 806,
"logprob": -0.12084961,
"special": false,
"text": " first"
},
{
"id": 908,
"logprob": -0.01737976,
"special": false,
"text": " used"
},
{
"id": 275,
"logprob": -0.4025879,
"special": false,
"text": " in"
}
]
},
"generated_text": "The word \"meme\" was first used in"
}
]
@@ -37,8 +37,8 @@ async def test_flash_neox_load(flash_neox, generate_load, response_snapshot):
    generated_texts = [r.generated_text for r in responses]
    assert len(generated_texts) == 4
-    assert generated_texts, all(
+    assert all(
        [text == generated_texts[0] for text in generated_texts]
-    )
+    ), generated_texts
    assert responses == response_snapshot
import pytest


@pytest.fixture(scope="module")
def neox_handle(launcher):
    with launcher(
        "stabilityai/stablelm-tuned-alpha-3b", num_shard=1, use_flash_attention=False
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def neox(neox_handle):
    await neox_handle.health(300)
    return neox_handle.client


@pytest.mark.skip
@pytest.mark.asyncio
async def test_neox(neox, response_snapshot):
    response = await neox.generate(
        "<|USER|>What's your mood today?<|ASSISTANT|>",
        max_new_tokens=10,
        decoder_input_details=True,
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.skip
@pytest.mark.asyncio
async def test_neox_load(neox, generate_load, response_snapshot):
    responses = await generate_load(
        neox,
        "<|USER|>What's your mood today?<|ASSISTANT|>",
        max_new_tokens=10,
        n=4,
    )

    generated_texts = [r.generated_text for r in responses]

    assert len(generated_texts) == 4
    assert all(
        [text == generated_texts[0] for text in generated_texts]
    ), generated_texts

    assert responses == response_snapshot
import pytest


@pytest.fixture(scope="module")
def neox_sharded_handle(launcher):
    with launcher(
        "OpenAssistant/oasst-sft-1-pythia-12b", num_shard=2, use_flash_attention=False
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def neox_sharded(neox_sharded_handle):
    await neox_sharded_handle.health(300)
    return neox_sharded_handle.client


@pytest.mark.skip
@pytest.mark.asyncio
async def test_neox(neox_sharded, response_snapshot):
    response = await neox_sharded.generate(
        "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>",
        max_new_tokens=10,
        decoder_input_details=True,
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.skip
@pytest.mark.asyncio
async def test_neox_load(neox_sharded, generate_load, response_snapshot):
    responses = await generate_load(
        neox_sharded,
        "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>",
        max_new_tokens=10,
        n=4,
    )

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot
[pytest]
+addopts = --snapshot-warn-unused
asyncio_mode = auto
markers =
    private: marks tests as requiring an admin hf token (deselect with '-m "not private"')
-include Makefile-transformers
include Makefile-flash-att
unit-tests:
@@ -17,7 +16,7 @@ install-torch:
	# Install specific version of torch
	pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 --no-cache-dir
-install: gen-server install-torch install-transformers
+install: gen-server install-torch
	pip install pip --upgrade
	pip install -r requirements.txt
	pip install -e ".[bnb, accelerate]"
...
transformers_commit := 69009822aa7897ffab97afb814e38126b83f639e

transformers:
	# Clone fork of transformers with custom CUDA kernels and sharding logic
	pip install --upgrade setuptools
	git clone https://github.com/OlivierDehaene/transformers.git

build-transformers: transformers
	cd transformers && git fetch && git checkout $(transformers_commit) && python setup.py build

install-transformers: build-transformers
	pip uninstall transformers -y || true
	cd transformers && python setup.py install
#include <ATen/Dispatch.h>
#include <THC/THCAtomics.cuh>
#include <ATen/ATen.h>
#include <torch/torch.h>
#include <vector>
#include <optional>
/**
* Friendly reminder of how multithreading works in CUDA: https://developer.nvidia.com/blog/even-easier-introduction-cuda
* Check example at https://github.com/thomasw21/LinearTransformers/blob/main/model/attention/fast_weight/fast_weight_cuda.cu
**/
// Available in pytorch main
//#define DISPATCH_CASE_FLOATING_TYPES(...) \
// at::AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
// at::AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
// at::AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
// at::AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
/*
* Forward passes
*/
/**
* cast to fp32 if in fp16 + mask + softmax computation in fp32 + cast back to original dtype
**/
template<typename attention_scores_scalar, int64_t min_kv_length_shard_size_per_thread>
__global__ void forward_masked_softmax_kernel(
const torch::PackedTensorAccessor32<attention_scores_scalar, 2, torch::RestrictPtrTraits> attention_scores, // [B, KV]
const torch::PackedTensorAccessor32<bool, 2, torch::RestrictPtrTraits> mask, // [B, KV]
torch::PackedTensorAccessor32<attention_scores_scalar, 2, torch::RestrictPtrTraits> result, // [B, KV]
const int64_t effective_kv_length,
const dim3 blockDim,
const int64_t rows_per_block,
const int64_t kv_length,
const int64_t batch_size
) {
const auto row_id = threadIdx.x / effective_kv_length;
const auto effective_kv_length_id = threadIdx.x % effective_kv_length;
const auto kv_length_start = effective_kv_length_id * min_kv_length_shard_size_per_thread;
auto kv_length_end_ = (effective_kv_length_id + 1) * min_kv_length_shard_size_per_thread;
kv_length_end_ = (kv_length_end_ > kv_length) ? kv_length : kv_length_end_;
const auto kv_length_end = kv_length_end_;
const auto batch_id = blockIdx.x * rows_per_block + row_id;
// We need 2 float storage for each row, one for max computation, the other for normalizing exponential
extern __shared__ float temp_storage[];
const auto row_id_mem_offset = row_id * 2;
if (effective_kv_length_id == 0) {
temp_storage[row_id_mem_offset] = -std::numeric_limits<float>::infinity();
temp_storage[row_id_mem_offset + 1] = 0;
}
__syncthreads();
// Compute mask and max
if (batch_id < batch_size) {
float thread_max = -std::numeric_limits<float>::infinity();
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
if (mask[batch_id][kv_length_id] == 0) {
const float candidate = attention_scores[batch_id][kv_length_id];
thread_max = (thread_max < candidate) ? candidate : thread_max;
}
}
if (thread_max != -std::numeric_limits<float>::infinity()) {
// TODO @thomasw21 with more memory we can probably compute a much faster `max-reduce` in parallel O(ln(n)) operations in each memory slot
gpuAtomicMax(&temp_storage[row_id_mem_offset], thread_max);
}
}
__syncthreads();
// Compute exp(elt - max) masked
float exponential[min_kv_length_shard_size_per_thread];
if (batch_id < batch_size) {
float thread_add = 0;
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
if (mask[batch_id][kv_length_id] == 0) {
exponential[kv_length_id - kv_length_start] = std::exp(static_cast<float>(attention_scores[batch_id][kv_length_id]) - temp_storage[row_id_mem_offset]);
thread_add = thread_add + exponential[kv_length_id - kv_length_start];
} else {
exponential[kv_length_id - kv_length_start] = 0.;
}
}
if (thread_add > 0) {
// TODO @thomasw21 with more memory we can probably compute a much faster `sum-reduce` in parallel O(ln(n)) operations in each memory slot
gpuAtomicAdd(&temp_storage[row_id_mem_offset + 1], thread_add);
}
}
__syncthreads();
// Compute softmax
if (batch_id < batch_size) {
// If sum of all exponential is 0, we set the softmax values to 0
if (temp_storage[row_id_mem_offset + 1] == 0.) {
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
result[batch_id][kv_length_id] = 0.;
}
} else {
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
result[batch_id][kv_length_id] = static_cast<attention_scores_scalar>(exponential[kv_length_id - kv_length_start] / temp_storage[row_id_mem_offset + 1]);
}
}
}
}
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::tuple<at::Tensor, std::optional<std::vector<at::Tensor>>, at::Tensor> forward(
const at::Tensor query,
const at::Tensor key,
const at::Tensor value,
const std::optional<std::vector<at::Tensor>> layer_past,
const at::Tensor attention_mask,
const std::optional<at::Tensor> head_mask,
const float inv_norm_factor,
const int num_heads,
const bool use_cache
) {
auto query_layer = query;
auto key_layer = key;
auto value_layer = value;
if (layer_past) {
const auto past_key = (*layer_past).at(0);
const auto past_value = (*layer_past).at(1);
key_layer = at::cat({past_key, key_layer}, 2);
value_layer = at::cat({past_value, value_layer}, 2);
}
std::optional<std::vector<at::Tensor>> present;
if (use_cache) {
present = {key_layer, value_layer};
} else {
present = {};
}
const auto batch_size = query_layer.size(0);
const auto q_length = query_layer.size(2);
const auto attn_head_size = query_layer.size(3);
const auto batch_size_times_num_heads = batch_size * num_heads;
const auto kv_length = key_layer.size(2);
const auto query_view = query_layer.reshape({batch_size_times_num_heads, q_length, attn_head_size});
auto key_view = key_layer.reshape({batch_size_times_num_heads, kv_length, attn_head_size}).transpose(1, 2);
auto value_view = value_layer.reshape({batch_size_times_num_heads, kv_length, attn_head_size});
auto query_scaled = query_view * inv_norm_factor;
auto attention_scores = at::bmm(query_scaled, key_view);
// Computing `optionally_cast_fp16_to_fp32 + masked_fill + softmax + cast_to_initial_dtype`
at::Tensor attention_probs;
if (true) {
// TODO @thomasw21: it's easier to think of attention_scores as 2D tensors
const auto attention_scores_2d = attention_scores.view({batch_size_times_num_heads * q_length, kv_length});
const auto attention_mask_2d = attention_mask.view({batch_size_times_num_heads * q_length, kv_length});
// Custom kernel
attention_probs = at::empty_like(attention_scores_2d);
// Check that inputs are contiguous CUDA tensors
CHECK_INPUT(attention_scores_2d);
CHECK_INPUT(attention_mask_2d);
// TODO @thomas21: change to this as it's cleaner when pytorch 1.13 comes out
// DISPATCH_CASE_FLOATING_TYPES(attention_scores.scalar_type(), "masked_softmax", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, attention_scores.scalar_type(), "masked_softmax", [&] {
/*
* Understanding how GPUs work: https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/
* A100 specifications: https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf
* - SMs: 108
* - TPCs: 56 (What's that?)
* - Memory size: 40 GB
* - L2 Cache size: 40960 KB (shared across all SMs)
* - L1/Shared memory size: 192 KB (shared across all threads within a SM)
* - Max Threads / SM: 2048
* - Max Thread Blocks / SM: 32
*/
/*
* We should split [batch_size_times_num_heads_block, q_length] in separate blocks and [batch_size_times_num_heads_block_size, kv_length] in a single block
* with multiple threads as we need to `sync_threads` to run exponential sum.
* We maximise the usage of threads within a single block
*/
// TODO @thomasw21 figure out everything warp related:
// - why do they have to be power of 2
// TODO @thomas21 check why everyone is setting 1024 when officially it's 2048
const auto MAX_THREADS_PER_SM = 1024;
// TODO @thomasw21 figure out how to have longer sequences, currently the maximum is `max_kv_length = MAX_THREADS_PER_SM * MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD`
const auto MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD = 4;
// `effective_kv_length = ceil(kv_length / MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD)`
const auto effective_kv_length = (kv_length - 1)/ MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD + 1;
const auto rows_per_block = MAX_THREADS_PER_SM / effective_kv_length;
const auto num_blocks = (batch_size_times_num_heads * q_length - 1) / rows_per_block + 1;
const dim3 gridDim(num_blocks); // Number of blocks that run
const dim3 blockDim(MAX_THREADS_PER_SM); // Number of threads that run per block
const int shared_mem_forward = rows_per_block * 2 * sizeof(float);
// 192 * 2 ** 10
// const auto MAX_L1_MEMORY = 196608;
// const auto MAX_SMs = 108;
// TORCH_CHECK(batch_size_times_num_heads * q_length <= MAX_L1_MEMORY, "Shared memory exceeds 192KB limitation.");
// TORCH_CHECK(gridDim.x * gridDim.y * gridDim.z <= MAX_SMs, "A100s only have 108 SMs. Raising as require blocks is bigger.");
// TORCH_CHECK(blockDim.x * blockDim.y * blockDim.z <= MAX_THREADS_PER_SM, "A100s only have 2048 threads per block. Raising as require requested threads is higher.");
forward_masked_softmax_kernel<scalar_t, MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD><<<gridDim, blockDim, shared_mem_forward>>>(
attention_scores_2d.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
attention_mask_2d.packed_accessor32<bool, 2, torch::RestrictPtrTraits>(),
attention_probs.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
effective_kv_length,
blockDim,
rows_per_block,
kv_length,
batch_size_times_num_heads * q_length
);
});
attention_probs = attention_probs.view({batch_size_times_num_heads, q_length, kv_length});
} else {
// Pytorch C++ API
auto input_dtype = attention_scores.scalar_type();
if (input_dtype == at::ScalarType::Float) {
attention_scores = attention_scores.to(at::ScalarType::Float);
};
// TODO @thomasw21 Figure out how to get minimum value
auto attn_weights = attention_scores.masked_fill_(attention_mask, -1e34);
attention_probs = attn_weights.softmax(-1, at::ScalarType::Float).to(input_dtype);
}
auto context_layer = attention_probs.bmm(value_view);
// `_merge_heads`
context_layer = context_layer.view({batch_size, num_heads, q_length, attn_head_size});
context_layer = context_layer.permute({0, 2, 1, 3});
context_layer = context_layer.reshape({batch_size, q_length, attn_head_size * num_heads});
return std::make_tuple(context_layer, present, attention_probs);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"forward",
&forward,
"GPT-Neox attention mechanism forward (CUDA)"
);
}
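
The `PYBIND11_MODULE` above exposes this kernel as `custom_kernels.fused_attention_cuda.forward` (built by the `setup.py` shown later in this commit). A rough Python usage sketch follows; the argument order comes from the C++ signature above, while the shapes and values are illustrative assumptions inferred from the reshapes in the kernel.

```python
import math

import torch
from custom_kernels import fused_attention_cuda  # built with -arch=compute_80 (Ampere+)

batch, num_heads, q_length, head_size = 2, 8, 16, 64
device, dtype = "cuda", torch.float16

# Inputs laid out as [batch, num_heads, seq_len, head_size], matching the reshapes above.
query = torch.randn(batch, num_heads, q_length, head_size, device=device, dtype=dtype)
key = torch.randn(batch, num_heads, q_length, head_size, device=device, dtype=dtype)
value = torch.randn(batch, num_heads, q_length, head_size, device=device, dtype=dtype)
# Boolean mask: True marks positions to mask out (the kernel keeps entries where mask == 0).
attention_mask = torch.zeros(
    batch * num_heads, q_length, q_length, dtype=torch.bool, device=device
)

context_layer, present, attention_probs = fused_attention_cuda.forward(
    query,
    key,
    value,
    None,                        # layer_past
    attention_mask,
    None,                        # head_mask (unused)
    1.0 / math.sqrt(head_size),  # inv_norm_factor
    num_heads,
    True,                        # use_cache -> `present` holds the key/value tensors
)
```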
#include <ATen/Dispatch.h>
#include <THC/THCAtomics.cuh>
#include <ATen/ATen.h>
#include <torch/torch.h>
#include <vector>
#include <optional>
/**
* Friendly reminder of how multithreading works in CUDA: https://developer.nvidia.com/blog/even-easier-introduction-cuda
* Check example at https://github.com/thomasw21/LinearTransformers/blob/main/model/attention/fast_weight/fast_weight_cuda.cu
**/
// Available in pytorch main
//#define DISPATCH_CASE_FLOATING_TYPES(...) \
// at::AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
// at::AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
// at::AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
// at::AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
/*
* Forward passes
*/
/**
* cast to fp32 if in fp16 + mask + softmax computation in fp32 + cast back to original dtype
**/
template<typename attention_scores_scalar, int64_t min_kv_length_shard_size_per_thread>
__global__ void forward_masked_softmax_kernel(
const torch::PackedTensorAccessor32<attention_scores_scalar, 2, torch::RestrictPtrTraits> attention_scores, // [B, KV]
const torch::PackedTensorAccessor32<bool, 2, torch::RestrictPtrTraits> mask, // [B, KV]
torch::PackedTensorAccessor32<attention_scores_scalar, 2, torch::RestrictPtrTraits> result, // [B, KV]
const int64_t effective_kv_length,
const dim3 blockDim,
const int64_t rows_per_block,
const int64_t kv_length,
const int64_t batch_size
) {
const auto row_id = threadIdx.x / effective_kv_length;
const auto effective_kv_length_id = threadIdx.x % effective_kv_length;
const auto kv_length_start = effective_kv_length_id * min_kv_length_shard_size_per_thread;
auto kv_length_end_ = (effective_kv_length_id + 1) * min_kv_length_shard_size_per_thread;
kv_length_end_ = (kv_length_end_ > kv_length) ? kv_length : kv_length_end_;
const auto kv_length_end = kv_length_end_;
const auto batch_id = blockIdx.x * rows_per_block + row_id;
// We need 2 float storage for each row, one for max computation, the other for normalizing exponential
extern __shared__ float temp_storage[];
const auto row_id_mem_offset = row_id * 2;
if (effective_kv_length_id == 0) {
temp_storage[row_id_mem_offset] = -std::numeric_limits<float>::infinity();
temp_storage[row_id_mem_offset + 1] = 0;
}
__syncthreads();
// Compute mask and max
if (batch_id < batch_size) {
float thread_max = -std::numeric_limits<float>::infinity();
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
if (mask[batch_id][kv_length_id] == 0) {
const float candidate = attention_scores[batch_id][kv_length_id];
thread_max = (thread_max < candidate) ? candidate : thread_max;
}
}
if (thread_max != -std::numeric_limits<float>::infinity()) {
// TODO @thomasw21 with more memory we can probably compute a much faster `max-reduce` in parallel O(ln(n)) operations in each memory slot
gpuAtomicMax(&temp_storage[row_id_mem_offset], thread_max);
}
}
__syncthreads();
// Compute exp(elt - max) masked
float exponential[min_kv_length_shard_size_per_thread];
if (batch_id < batch_size) {
float thread_add = 0;
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
if (mask[batch_id][kv_length_id] == 0) {
exponential[kv_length_id - kv_length_start] = std::exp(static_cast<float>(attention_scores[batch_id][kv_length_id]) - temp_storage[row_id_mem_offset]);
thread_add = thread_add + exponential[kv_length_id - kv_length_start];
} else {
exponential[kv_length_id - kv_length_start] = 0.;
}
}
if (thread_add > 0) {
// TODO @thomasw21 with more memory we can probably compute a much faster `sum-reduce` in parallel O(ln(n)) operations in each memory slot
gpuAtomicAdd(&temp_storage[row_id_mem_offset + 1], thread_add);
}
}
__syncthreads();
// Compute softmax
if (batch_id < batch_size) {
// If sum of all exponential is 0, we set the softmax values to 0
if (temp_storage[row_id_mem_offset + 1] == 0.) {
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
result[batch_id][kv_length_id] = 0.;
}
} else {
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
result[batch_id][kv_length_id] = static_cast<attention_scores_scalar>(exponential[kv_length_id - kv_length_start] / temp_storage[row_id_mem_offset + 1]);
}
}
}
}
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::tuple<at::Tensor, std::optional<std::vector<at::Tensor>>, at::Tensor> forward(
const at::Tensor fused_qkv,
const std::optional<std::vector<at::Tensor>> layer_past,
const at::Tensor alibi,
const at::Tensor attention_mask,
const std::optional<at::Tensor> head_mask,
const float beta,
const float inv_norm_factor,
const int num_heads,
const bool use_cache
) {
const auto batch_size = fused_qkv.size(0);
const auto q_length = fused_qkv.size(1);
const auto three_times_hidden_size = fused_qkv.size(2);
const auto head_dim = three_times_hidden_size / (3 * num_heads);
const auto batch_size_times_num_heads = batch_size * num_heads;
// `split_heads`
const auto fused_qkv_view = fused_qkv.view({batch_size, q_length, num_heads, 3 * head_dim});
const auto tensor_list = fused_qkv_view.split(head_dim, -1);
const auto query_layer = tensor_list[0].transpose(1, 2).reshape({batch_size_times_num_heads, q_length, head_dim});
auto key_layer = tensor_list[1].permute({0, 2, 3, 1}).reshape({batch_size_times_num_heads, head_dim, q_length});
auto value_layer = tensor_list[2].transpose(1, 2).reshape({batch_size_times_num_heads, q_length, head_dim});
if (layer_past) {
const auto past_key = (*layer_past).at(0);
const auto past_value = (*layer_past).at(1);
key_layer = at::cat({past_key, key_layer}, 2);
value_layer = at::cat({past_value, value_layer}, 1);
}
std::optional<std::vector<at::Tensor>> present;
if (use_cache) {
present = {key_layer, value_layer};
} else {
present = {};
}
auto attention_scores = alibi.baddbmm(query_layer, key_layer, beta, inv_norm_factor);
// Computing `optionally_cast_fp16_to_fp32 + masked_fill + softmax + cast_to_initial_dtype`
at::Tensor attention_probs;
if (true) {
const auto kv_length = key_layer.size(2);
// TODO @thomasw21: it's easier to think of attention_scores as 2D tensors
const auto attention_scores_2d = attention_scores.view({batch_size_times_num_heads * q_length, kv_length});
const auto attention_mask_2d = attention_mask.view({batch_size_times_num_heads * q_length, kv_length});
// Custom kernel
attention_probs = at::empty_like(attention_scores_2d);
// Check that inputs are contiguous CUDA tensors
CHECK_INPUT(attention_scores_2d);
CHECK_INPUT(attention_mask_2d);
// TODO @thomas21: change to this as it's cleaner when pytorch 1.13 comes out
// DISPATCH_CASE_FLOATING_TYPES(attention_scores.scalar_type(), "masked_softmax", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, attention_scores.scalar_type(), "masked_softmax", [&] {
/*
* Understanding how GPUs work: https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/
* A100 specifications: https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf
* - SMs: 108
* - TPCs: 56 (What's that?)
* - Memory size: 40 GB
* - L2 Cache size: 40960 KB (shared across all SMs)
* - L1/Shared memory size: 192 KB (shared across all threads within a SM)
* - Max Threads / SM: 2048
* - Max Thread Blocks / SM: 32
*/
/*
* We should split [batch_size_times_num_heads_block, q_length] in separate blocks and [batch_size_times_num_heads_block_size, kv_length] in a single block
* with multiple threads as we need to `sync_threads` to run exponential sum.
* We maximise the usage of threads within a single block
*/
// TODO @thomasw21 figure out everything warp related:
// - why do they have to be power of 2
// TODO @thomas21 check why everyone is setting 1024 when officially it's 2048
const auto MAX_THREADS_PER_SM = 1024;
// TODO @thomasw21 figure out how to have longer sequences, currently the maximum is `max_kv_length = MAX_THREADS_PER_SM * MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD`
const auto MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD = 4;
// `effective_kv_length = ceil(kv_length / MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD)`
const auto effective_kv_length = (kv_length - 1)/ MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD + 1;
const auto rows_per_block = MAX_THREADS_PER_SM / effective_kv_length;
const auto num_blocks = (batch_size_times_num_heads * q_length - 1) / rows_per_block + 1;
const dim3 gridDim(num_blocks); // Number of blocks that run
const dim3 blockDim(MAX_THREADS_PER_SM); // Number of threads that run per block
const int shared_mem_forward = rows_per_block * 2 * sizeof(float);
// 192 * 2 ** 10
// const auto MAX_L1_MEMORY = 196608;
// const auto MAX_SMs = 108;
// TORCH_CHECK(batch_size_times_num_heads * q_length <= MAX_L1_MEMORY, "Shared memory exceeds 192KB limitation.");
// TORCH_CHECK(gridDim.x * gridDim.y * gridDim.z <= MAX_SMs, "A100s only have 108 SMs. Raising as require blocks is bigger.");
// TORCH_CHECK(blockDim.x * blockDim.y * blockDim.z <= MAX_THREADS_PER_SM, "A100s only have 2048 threads per block. Raising as require requested threads is higher.");
forward_masked_softmax_kernel<scalar_t, MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD><<<gridDim, blockDim, shared_mem_forward>>>(
attention_scores_2d.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
attention_mask_2d.packed_accessor32<bool, 2, torch::RestrictPtrTraits>(),
attention_probs.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
effective_kv_length,
blockDim,
rows_per_block,
kv_length,
batch_size_times_num_heads * q_length
);
});
attention_probs = attention_probs.view({batch_size_times_num_heads, q_length, kv_length});
} else {
// Pytorch C++ API
auto input_dtype = attention_scores.scalar_type();
if (input_dtype == at::ScalarType::Float) {
attention_scores = attention_scores.to(at::ScalarType::Float);
};
// TODO @thomasw21 Figure out how to get minimum value
auto attn_weights = attention_scores.masked_fill_(attention_mask, -1e34);
attention_probs = attn_weights.softmax(-1, at::ScalarType::Float).to(input_dtype);
}
auto context_layer = attention_probs.bmm(value_layer);
// `_merge_heads`
context_layer = context_layer.view({batch_size, num_heads, q_length, head_dim});
context_layer = context_layer.permute({0, 2, 1, 3});
context_layer = context_layer.reshape({batch_size, q_length, three_times_hidden_size / 3});
return std::make_tuple(context_layer, present, attention_probs);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"forward",
&forward,
"Bloom attention mechanism forward (CUDA)"
);
}
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="custom_kernels",
    ext_modules=[
        CUDAExtension(
            name="custom_kernels.fused_bloom_attention_cuda",
            sources=["custom_kernels/fused_bloom_attention_cuda.cu"],
            extra_compile_args=["-arch=compute_80", "-std=c++17"],
        ),
        CUDAExtension(
            name="custom_kernels.fused_attention_cuda",
            sources=["custom_kernels/fused_attention_cuda.cu"],
            extra_compile_args=["-arch=compute_80", "-std=c++17"],
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)
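
As a quick, illustrative smoke check (not part of the commit): after `python setup.py build` as in the Dockerfile, or `make install-custom-kernels`, both extensions should import under the `custom_kernels` package name declared above.

```python
import torch

from custom_kernels import fused_attention_cuda, fused_bloom_attention_cuda

# The kernels are compiled with -arch=compute_80, so a capable CUDA device is required.
assert torch.cuda.is_available()
print(fused_attention_cuda.forward.__doc__)
print(fused_bloom_attention_cuda.forward.__doc__)
```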
@@ -25,7 +25,8 @@ opentelemetry-instrumentation-grpc = "^0.36b0"
hf-transfer = "^0.1.2"
sentencepiece = "^0.1.97"
tokenizers = "0.13.3"
-huggingface-hub = "0.14.0"
+huggingface-hub = "^0.14.1"
+transformers = "^4.29.2"
[tool.poetry.extras]
accelerate = ["accelerate"]
...
@@ -13,8 +13,8 @@ grpcio-reflection==1.55.0 ; python_version >= "3.9" and python_version < "4.0"
grpcio-status==1.55.0 ; python_version >= "3.9" and python_version < "4.0"
grpcio==1.55.0 ; python_version >= "3.9" and python_version < "4.0"
hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "4.0"
-huggingface-hub==0.14.0 ; python_version >= "3.9" and python_version < "4.0"
+huggingface-hub==0.14.1 ; python_version >= "3.9" and python_version < "4.0"
-idna==3.4 ; python_version >= "3.9" and python_version < "4.0"
+idna==3.4 ; python_version >= "3.9" and python_version < "4"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
@@ -33,6 +33,7 @@ safetensors==0.3.1 ; python_version >= "3.9" and python_version < "4.0"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "4.0"
setuptools==67.8.0 ; python_version >= "3.9" and python_version < "4.0"
tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "4.0"
+transformers==4.29.2 ; python_version >= "3.9" and python_version < "4.0"
tqdm==4.65.0 ; python_version >= "3.9" and python_version < "4.0"
typer==0.6.1 ; python_version >= "3.9" and python_version < "4.0"
typing-extensions==4.6.0 ; python_version >= "3.9" and python_version < "4.0"
...