Unverified Commit b8b950b3 authored by OlivierDehaene's avatar OlivierDehaene Committed by GitHub
Browse files

feat(server): support RefinedWeb models (#379)

parent bf7f1d54
...@@ -205,7 +205,10 @@ def event_loop(): ...@@ -205,7 +205,10 @@ def event_loop():
def launcher(event_loop): def launcher(event_loop):
@contextlib.contextmanager @contextlib.contextmanager
def local_launcher( def local_launcher(
model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None model_id: str,
num_shard: Optional[int] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
): ):
port = random.randint(8000, 10_000) port = random.randint(8000, 10_000)
master_port = random.randint(10_000, 20_000) master_port = random.randint(10_000, 20_000)
...@@ -230,6 +233,9 @@ def launcher(event_loop): ...@@ -230,6 +233,9 @@ def launcher(event_loop):
args.extend(["--num-shard", str(num_shard)]) args.extend(["--num-shard", str(num_shard)])
if quantize: if quantize:
args.append("--quantize") args.append("--quantize")
args.append("bitsandbytes")
if trust_remote_code:
args.append("--trust-remote-code")
env = os.environ env = os.environ
env["LOG_LEVEL"] = "info,text_generation_router=debug" env["LOG_LEVEL"] = "info,text_generation_router=debug"
...@@ -250,7 +256,10 @@ def launcher(event_loop): ...@@ -250,7 +256,10 @@ def launcher(event_loop):
@contextlib.contextmanager @contextlib.contextmanager
def docker_launcher( def docker_launcher(
model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None model_id: str,
num_shard: Optional[int] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
): ):
port = random.randint(8000, 10_000) port = random.randint(8000, 10_000)
...@@ -260,6 +269,9 @@ def launcher(event_loop): ...@@ -260,6 +269,9 @@ def launcher(event_loop):
args.extend(["--num-shard", str(num_shard)]) args.extend(["--num-shard", str(num_shard)])
if quantize: if quantize:
args.append("--quantize") args.append("--quantize")
args.append("bitsandbytes")
if trust_remote_code:
args.append("--trust-remote-code")
client = docker.from_env() client = docker.from_env()
......
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50,
"logprob": null,
"text": "G"
},
{
"id": 330,
"logprob": -5.96875,
"text": "ir"
},
{
"id": 1622,
"logprob": -5.6132812,
"text": "af"
},
{
"id": 249,
"logprob": -6.5039062,
"text": "at"
},
{
"id": 1480,
"logprob": -8.078125,
"text": "ron"
},
{
"id": 304,
"logprob": -2.3261719,
"text": " is"
},
{
"id": 23866,
"logprob": -9.59375,
"text": " obsessed"
},
{
"id": 335,
"logprob": -0.048339844,
"text": " with"
},
{
"id": 26680,
"logprob": -4.0,
"text": " gir"
},
{
"id": 1903,
"logprob": -0.07556152,
"text": "aff"
},
{
"id": 255,
"logprob": -0.0067749023,
"text": "es"
},
{
"id": 23,
"logprob": -1.546875,
"text": ","
},
{
"id": 248,
"logprob": -4.3320312,
"text": " the"
},
{
"id": 758,
"logprob": -3.734375,
"text": " most"
},
{
"id": 21735,
"logprob": -5.109375,
"text": " glorious"
},
{
"id": 5985,
"logprob": -2.09375,
"text": " animal"
},
{
"id": 313,
"logprob": -1.1835938,
"text": " on"
},
{
"id": 248,
"logprob": -0.77685547,
"text": " the"
},
{
"id": 1936,
"logprob": -2.3828125,
"text": " face"
},
{
"id": 275,
"logprob": -0.004432678,
"text": " of"
},
{
"id": 414,
"logprob": -1.9677734,
"text": " this"
},
{
"id": 6490,
"logprob": -2.046875,
"text": " Earth"
},
{
"id": 25,
"logprob": -0.28198242,
"text": "."
},
{
"id": 401,
"logprob": -7.9179688,
"text": " G"
},
{
"id": 6013,
"logprob": -2.2753906,
"text": "ira"
},
{
"id": 694,
"logprob": -0.6230469,
"text": "ft"
},
{
"id": 1480,
"logprob": -0.20874023,
"text": "ron"
},
{
"id": 9369,
"logprob": -4.5507812,
"text": " believes"
},
{
"id": 455,
"logprob": -4.5664062,
"text": " all"
},
{
"id": 599,
"logprob": -2.7402344,
"text": " other"
},
{
"id": 5632,
"logprob": -0.21948242,
"text": " animals"
},
{
"id": 362,
"logprob": -0.7675781,
"text": " are"
},
{
"id": 23981,
"logprob": -5.0,
"text": " irrelevant"
},
{
"id": 635,
"logprob": -4.234375,
"text": " when"
},
{
"id": 4354,
"logprob": -0.5131836,
"text": " compared"
},
{
"id": 271,
"logprob": -0.103637695,
"text": " to"
},
{
"id": 248,
"logprob": -0.58447266,
"text": " the"
},
{
"id": 21735,
"logprob": -3.6835938,
"text": " glorious"
},
{
"id": 64398,
"logprob": -1.8173828,
"text": " majesty"
},
{
"id": 275,
"logprob": -0.23510742,
"text": " of"
},
{
"id": 248,
"logprob": -0.35473633,
"text": " the"
},
{
"id": 26680,
"logprob": -0.24633789,
"text": " gir"
},
{
"id": 23226,
"logprob": -0.02960205,
"text": "affe"
},
{
"id": 25,
"logprob": -0.17333984,
"text": "."
},
{
"id": 193,
"logprob": -1.3935547,
"text": "\n"
},
{
"id": 23626,
"logprob": -10.0625,
"text": "Daniel"
},
{
"id": 37,
"logprob": -4.59375,
"text": ":"
},
{
"id": 23090,
"logprob": -6.9375,
"text": " Hello"
},
{
"id": 23,
"logprob": -0.99365234,
"text": ","
},
{
"id": 29033,
"logprob": -2.2324219,
"text": " Gir"
},
{
"id": 1622,
"logprob": -0.10809326,
"text": "af"
},
{
"id": 249,
"logprob": -0.042663574,
"text": "at"
},
{
"id": 1480,
"logprob": -0.0024776459,
"text": "ron"
},
{
"id": 12,
"logprob": -1.4277344,
"text": "!"
},
{
"id": 193,
"logprob": -1.1015625,
"text": "\n"
},
{
"id": 50,
"logprob": -0.05709839,
"text": "G"
},
{
"id": 330,
"logprob": -0.13208008,
"text": "ir"
},
{
"id": 1622,
"logprob": -0.0071487427,
"text": "af"
},
{
"id": 249,
"logprob": -0.008468628,
"text": "at"
},
{
"id": 1480,
"logprob": -0.00068998337,
"text": "ron"
},
{
"id": 37,
"logprob": -0.0074691772,
"text": ":"
}
],
"seed": null,
"tokens": [
{
"id": 23090,
"logprob": -1.8251953,
"special": false,
"text": " Hello"
},
{
"id": 23,
"logprob": -0.3173828,
"special": false,
"text": ","
},
{
"id": 8156,
"logprob": -0.23803711,
"special": false,
"text": " Daniel"
},
{
"id": 12,
"logprob": -0.56933594,
"special": false,
"text": "!"
},
{
"id": 193,
"logprob": -0.61279297,
"special": false,
"text": "\n"
},
{
"id": 23626,
"logprob": -0.41967773,
"special": false,
"text": "Daniel"
},
{
"id": 37,
"logprob": -0.0023403168,
"special": false,
"text": ":"
},
{
"id": 1634,
"logprob": -2.0605469,
"special": false,
"text": " What"
},
{
"id": 18,
"logprob": -1.5292969,
"special": false,
"text": "'"
},
{
"id": 94,
"logprob": -0.007904053,
"special": false,
"text": "s"
}
]
},
"generated_text": " Hello, Daniel!\nDaniel: What's"
}
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 330,
"logprob": null,
"text": "ir"
},
{
"id": 1622,
"logprob": -7.8125,
"text": "af"
},
{
"id": 249,
"logprob": -4.5,
"text": "at"
},
{
"id": 1480,
"logprob": -10.875,
"text": "ron"
},
{
"id": 37,
"logprob": -3.6875,
"text": ":"
}
],
"seed": 0,
"tokens": [
{
"id": 836,
"logprob": -1.265625,
"special": false,
"text": " i"
},
{
"id": 18,
"logprob": -0.119628906,
"special": false,
"text": "'"
},
{
"id": 298,
"logprob": -2.265625,
"special": false,
"text": "ve"
},
{
"id": 650,
"logprob": -0.49804688,
"special": false,
"text": " been"
},
{
"id": 1241,
"logprob": 0.0,
"special": false,
"text": " using"
},
{
"id": 334,
"logprob": 0.0,
"special": false,
"text": " it"
},
{
"id": 312,
"logprob": -1.2421875,
"special": false,
"text": " for"
},
{
"id": 909,
"logprob": -0.99609375,
"special": false,
"text": " years"
},
{
"id": 193,
"logprob": -0.30273438,
"special": false,
"text": "\n"
},
{
"id": 807,
"logprob": -1.078125,
"special": false,
"text": "ik"
}
]
},
"generated_text": "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron: i've been using it for years\nik"
}
import pytest
@pytest.fixture(scope="module")
def flash_falcon_handle(launcher):
with launcher("tiiuae/falcon-7b", trust_remote_code=True) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_falcon(flash_falcon_handle):
await flash_falcon_handle.health(120)
return flash_falcon_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_falcon(flash_falcon, response_snapshot):
response = await flash_falcon.generate(
"Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:",
max_new_tokens=10,
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_falcon_all_params(flash_falcon, response_snapshot):
response = await flash_falcon.generate(
"Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["test"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
seed=0,
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_falcon_load(flash_falcon, generate_load, response_snapshot):
responses = await generate_load(
flash_falcon,
"Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:",
max_new_tokens=10,
n=4,
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot
...@@ -10,6 +10,7 @@ from text_generation_server.models.causal_lm import CausalLM ...@@ -10,6 +10,7 @@ from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOM, BLOOMSharded from text_generation_server.models.bloom import BLOOM, BLOOMSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPT, OPTSharded from text_generation_server.models.opt import OPT, OPTSharded
from text_generation_server.models.galactica import Galactica, GalacticaSharded from text_generation_server.models.galactica import Galactica, GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder from text_generation_server.models.santacoder import SantaCoder
...@@ -30,6 +31,7 @@ try: ...@@ -30,6 +31,7 @@ try:
) )
from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
from text_generation_server.models.flash_rw import FlashRW, FlashRWSharded
from text_generation_server.models.flash_llama import ( from text_generation_server.models.flash_llama import (
FlashLlama, FlashLlama,
FlashLlamaSharded, FlashLlamaSharded,
...@@ -68,6 +70,8 @@ __all__ = [ ...@@ -68,6 +70,8 @@ __all__ = [
if FLASH_ATTENTION: if FLASH_ATTENTION:
__all__.append(FlashNeoX) __all__.append(FlashNeoX)
__all__.append(FlashNeoXSharded) __all__.append(FlashNeoXSharded)
__all__.append(FlashRW)
__all__.append(FlashRWSharded)
__all__.append(FlashSantacoder) __all__.append(FlashSantacoder)
__all__.append(FlashSantacoderSharded) __all__.append(FlashSantacoderSharded)
__all__.append(FlashLlama) __all__.append(FlashLlama)
...@@ -194,6 +198,39 @@ def get_model( ...@@ -194,6 +198,39 @@ def get_model(
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
if model_type in ["RefinedWeb", "RefinedWebModel"]:
if sharded:
if FLASH_ATTENTION:
if config.alibi or (
config.model_type == "RefinedWebModel"
and config.n_head_kv != config.n_head
):
raise NotImplementedError("sharded is not supported for this model")
return FlashRWSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded RefinedWeb")
)
else:
if FLASH_ATTENTION and not config.alibi:
return FlashRW(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
else:
return RW(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
if model_type == "llama": if model_type == "llama":
if sharded: if sharded:
if FLASH_ATTENTION: if FLASH_ATTENTION:
......
...@@ -134,20 +134,23 @@ class FlashLlamaAttention(torch.nn.Module): ...@@ -134,20 +134,23 @@ class FlashLlamaAttention(torch.nn.Module):
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size) qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
qkv_rot = self.rotary_emb(qkv, cos, sin)
# Inplace rotary
self.rotary_emb(qkv[:, 0], cos, sin)
self.rotary_emb(qkv[:, 1], cos, sin)
# Prefill # Prefill
if layer_past_present_indices is None: if layer_past_present_indices is None:
# Copy to layer past # Copy to layer past
layer_past[...] = qkv_rot[:, 1:] layer_past[...] = qkv[:, 1:]
# output # output
attn_output = torch.empty_like(qkv_rot[:, 0]) attn_output = torch.empty_like(qkv[:, 0])
# flash attention # flash attention
flash_attn_cuda.fwd( flash_attn_cuda.fwd(
qkv_rot[:, 0], qkv[:, 0],
qkv_rot[:, 1], qkv[:, 1],
qkv_rot[:, 2], qkv[:, 2],
attn_output, attn_output,
cu_seqlens, cu_seqlens,
cu_seqlens, cu_seqlens,
...@@ -163,9 +166,9 @@ class FlashLlamaAttention(torch.nn.Module): ...@@ -163,9 +166,9 @@ class FlashLlamaAttention(torch.nn.Module):
) )
# Decode # Decode
else: else:
query = qkv_rot[:, 0] query = qkv[:, 0]
# Add present to the layer_past tensor at the correct indices # Add present to the layer_past tensor at the correct indices
layer_past[layer_past_present_indices] = qkv_rot[:, 1:] layer_past[layer_past_present_indices] = qkv[:, 1:]
# output # output
attn_output = torch.empty_like(query) attn_output = torch.empty_like(query)
......
...@@ -101,20 +101,23 @@ class FlashNeoxAttention(torch.nn.Module): ...@@ -101,20 +101,23 @@ class FlashNeoxAttention(torch.nn.Module):
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size) qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
qkv_rot = self.rotary_emb(qkv, cos, sin)
# Inplace rotary
self.rotary_emb(qkv[:, 0], cos, sin)
self.rotary_emb(qkv[:, 1], cos, sin)
# Prefill # Prefill
if layer_past_present_indices is None: if layer_past_present_indices is None:
# Copy to layer past # Copy to layer past
layer_past[...] = qkv_rot[:, 1:] layer_past[...] = qkv[:, 1:]
# output # output
attn_output = torch.empty_like(qkv_rot[:, 0]) attn_output = torch.empty_like(qkv[:, 0])
# flash attention # flash attention
flash_attn_cuda.fwd( flash_attn_cuda.fwd(
qkv_rot[:, 0], qkv[:, 0],
qkv_rot[:, 1], qkv[:, 1],
qkv_rot[:, 2], qkv[:, 2],
attn_output, attn_output,
cu_seqlens, cu_seqlens,
cu_seqlens, cu_seqlens,
...@@ -130,9 +133,9 @@ class FlashNeoxAttention(torch.nn.Module): ...@@ -130,9 +133,9 @@ class FlashNeoxAttention(torch.nn.Module):
) )
# Decode # Decode
else: else:
query = qkv_rot[:, 0] query = qkv[:, 0]
# Add present to the layer_past tensor at the correct indices # Add present to the layer_past tensor at the correct indices
layer_past[layer_past_present_indices] = qkv_rot[:, 1:] layer_past[layer_past_present_indices] = qkv[:, 1:]
# output # output
attn_output = torch.empty_like(query) attn_output = torch.empty_like(query)
......
import torch
import torch.distributed
from pathlib import Path
from accelerate import init_empty_weights
from opentelemetry import trace
from safetensors import safe_open
from transformers import AutoTokenizer, AutoConfig
from typing import Optional, List
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
RWConfig,
FlashRWForCausalLM,
TensorParallelEmbedding,
TensorParallelRowLinear,
TensorParallelColumnLinear,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
download_weights,
weight_hub_files,
LocalEntryNotFoundError,
)
tracer = trace.get_tracer(__name__)
class FlashRW(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16
else:
raise NotImplementedError("RW is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = RWConfig.from_pretrained(
model_id,
revision=revision,
)
# We do not use from_pretrained as it is too slow
try:
filenames = weight_files(model_id, revision, ".bin")
# Local files not found
except LocalEntryNotFoundError:
hub_files = weight_hub_files(model_id, revision, ".bin")
filenames = download_weights(hub_files, model_id, revision)
with init_empty_weights():
model = FlashRWForCausalLM(config)
self.load_weights(
model,
filenames,
quantize,
device,
dtype,
)
super(FlashCausalLM, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
)
@staticmethod
def load_weights(
model: FlashRWForCausalLM,
filenames: List[Path],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
):
for filename in filenames:
state_dict = torch.load(filename, map_location="cpu")
for key, value in state_dict.items():
value = value.to(device if quantize is None else "cpu").to(dtype)
module_name, param_name = key.rsplit(".", 1)
module = model.get_submodule(module_name)
try:
current_parameter_tensor = module._parameters[param_name]
if current_parameter_tensor.shape != value.shape:
raise ValueError(
f"Name {key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
)
module._parameters[param_name] = value
except KeyError:
module._buffers[param_name] = value
del value
torch.cuda.empty_cache()
model.post_load_weights(quantize)
class FlashRWSharded(FlashRW):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16
else:
raise NotImplementedError("FlashRW is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = RWConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
with init_empty_weights():
model = FlashRWForCausalLM(config, self.process_group)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
module_name, param_name = name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_parameter_tensor = parameters.get(name, None)
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif name == "lm_head.weight" and model.transformer.tp_embeddings:
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
try:
tensor = slice_[:]
except:
tensor = f.get_tensor(name)
if (
current_parameter_tensor is not None
and current_parameter_tensor.shape != tensor.shape
):
raise ValueError(
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
)
tensor = tensor.contiguous().to(dtype)
if current_parameter_tensor is not None:
module._parameters[param_name] = tensor
else:
module._buffers[param_name] = tensor
model.post_load_weights(quantize)
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Optional, Tuple
from text_generation_server.models import CausalLM
class RW(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16
else:
if quantize:
raise ValueError("quantization is not available on CPU")
device = torch.device("cpu")
dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
model = AutoModelForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
device_map="auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1
else None,
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
)
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
model = model.cuda()
if tokenizer.pad_token_id is None:
if model.config.pad_token_id is not None:
tokenizer.pad_token_id = model.config.pad_token_id
elif model.config.eos_token_id is not None:
tokenizer.pad_token_id = model.config.eos_token_id
elif tokenizer.eos_token_id is not None:
tokenizer.pad_token_id = tokenizer.eos_token_id
else:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
)
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
# Model Forward
if past_key_values is not None:
reshaped_past_key_values = []
for layer in past_key_values:
past_keys, past_values = layer
reshaped_past_key_values.append(
(
past_keys.view(-1, *past_keys.shape[-2:]),
past_values.view(-1, *past_values.shape[-2:]),
)
)
past_key_values = reshaped_past_key_values
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
)
return outputs.logits, outputs.past_key_values
...@@ -262,16 +262,13 @@ try: ...@@ -262,16 +262,13 @@ try:
sin = torch.index_select(self._sin_cached, 0, position_ids) sin = torch.index_select(self._sin_cached, 0, position_ids)
return cos.unsqueeze(1), sin.unsqueeze(1) return cos.unsqueeze(1), sin.unsqueeze(1)
def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
rotary_dim = cos.shape[-1] rotary_dim = cos.shape[-1]
q1 = qkv[:, 0, :, :rotary_dim] x1 = x[..., :rotary_dim]
q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim] x2 = x[..., rotary_dim : 2 * rotary_dim]
k1 = qkv[:, 1, :, :rotary_dim]
k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim] rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
return x
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
return qkv
except ImportError: except ImportError:
pass pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment