Commit 72ee382d authored by OlivierDehaene

chore: formatting

parent 3a521c92
@@ -25,6 +25,7 @@ DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")
 
 class ResponseComparator(JSONSnapshotExtension):
     rtol = 0.2
+
     def serialize(
         self,
         data,
@@ -69,7 +70,9 @@ class ResponseComparator(JSONSnapshotExtension):
                 prefill_token.id == other.id
                 and prefill_token.text == other.text
                 and (
-                    math.isclose(prefill_token.logprob, other.logprob, rel_tol=self.rtol)
+                    math.isclose(
+                        prefill_token.logprob, other.logprob, rel_tol=self.rtol
+                    )
                     if prefill_token.logprob is not None
                     else prefill_token.logprob == other.logprob
                 )
@@ -153,6 +156,7 @@ class GenerousResponseComparator(ResponseComparator):
     # Needed for GPTQ with exllama which has serious numerical fluctuations.
     rtol = 0.75
 
+
 class LauncherHandle:
     def __init__(self, port: int):
         self.client = AsyncClient(f"http://localhost:{port}")
@@ -198,6 +202,7 @@ class ProcessLauncherHandle(LauncherHandle):
 def response_snapshot(snapshot):
     return snapshot.use_extension(ResponseComparator)
 
+
 @pytest.fixture
 def generous_response_snapshot(snapshot):
     return snapshot.use_extension(GenerousResponseComparator)
@@ -219,7 +224,7 @@ def launcher(event_loop):
         quantize: Optional[str] = None,
         trust_remote_code: bool = False,
         use_flash_attention: bool = True,
-        dtype: Optional[str] = None
+        dtype: Optional[str] = None,
     ):
         port = random.randint(8000, 10_000)
         master_port = random.randint(10_000, 20_000)
@@ -282,7 +287,7 @@ def launcher(event_loop):
         quantize: Optional[str] = None,
         trust_remote_code: bool = False,
         use_flash_attention: bool = True,
-        dtype: Optional[str] = None
+        dtype: Optional[str] = None,
     ):
         port = random.randint(8000, 10_000)
 
@@ -335,7 +340,7 @@ def launcher(event_loop):
             ],
             volumes=volumes,
             ports={"80/tcp": port},
-            shm_size="1G"
+            shm_size="1G",
         )
 
         yield ContainerLauncherHandle(client, container.name, port)
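A note on the tolerance comparison reflowed in the first hunk: `math.isclose(a, b, rel_tol=rtol)` accepts two values when `abs(a - b) <= rtol * max(abs(a), abs(b))`, so with `rtol = 0.2` two logprobs may differ by up to 20% of the larger magnitude and still match the snapshot. A minimal sketch with made-up logprob values (not taken from any snapshot):

```python
import math

rtol = 0.2  # same relative tolerance as ResponseComparator above

# |-1.00 - (-1.15)| = 0.15 <= 0.2 * 1.15, so these logprobs count as equal
assert math.isclose(-1.00, -1.15, rel_tol=rtol)

# |-1.00 - (-1.30)| = 0.30 > 0.2 * 1.30, so these do not
assert not math.isclose(-1.00, -1.30, rel_tol=rtol)
```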
...
@@ -50,10 +50,16 @@ async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
 
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot):
-    responses = await generate_load(flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4)
+    responses = await generate_load(
+        flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4
+    )
 
     assert len(responses) == 4
-    assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}"
-    assert responses[0].generated_text == '\nDeep learning is a subset of machine learning'
+    assert all(
+        [r.generated_text == responses[0].generated_text for r in responses]
+    ), f"{[r.generated_text for r in responses]}"
+    assert (
+        responses[0].generated_text == "\nDeep learning is a subset of machine learning"
+    )
 
     assert responses == response_snapshot
@@ -56,7 +56,9 @@ async def test_flash_mistral_load(flash_mistral, generate_load, response_snapshot):
     )
 
     assert len(responses) == 4
-    assert all([r.generated_text == responses[0].generated_text for r in responses]), f"{[r.generated_text for r in responses]}"
+    assert all(
+        [r.generated_text == responses[0].generated_text for r in responses]
+    ), f"{[r.generated_text for r in responses]}"
     assert responses[0].generated_text == ": Let n = 10 - 1"
 
     assert responses == response_snapshot
@@ -3,7 +3,9 @@ import pytest
 
 
 @pytest.fixture(scope="module")
 def idefics_handle(launcher):
-    with launcher("HuggingFaceM4/idefics-9b-instruct", num_shard=2, dtype="float16") as handle:
+    with launcher(
+        "HuggingFaceM4/idefics-9b-instruct", num_shard=2, dtype="float16"
+    ) as handle:
         yield handle
...
@@ -133,8 +133,20 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     )
 
     assert all([generation.generated_text is None for generation in generations])
     assert all([len(generation.prefill_tokens) == 1 for generation in generations])
-    assert all([token_id.item() == 10264 for generation in generations for token_id in generation.tokens.token_ids])
-    assert all([token_text == "Test" for generation in generations for token_text in generation.tokens.texts])
+    assert all(
+        [
+            token_id.item() == 10264
+            for generation in generations
+            for token_id in generation.tokens.token_ids
+        ]
+    )
+    assert all(
+        [
+            token_text == "Test"
+            for generation in generations
+            for token_text in generation.tokens.texts
+        ]
+    )
 
     assert generations[0].request_id == 0
...
@@ -129,8 +129,20 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     )
 
     assert all([generation.generated_text is None for generation in generations])
     assert all([len(generation.prefill_tokens) == 1 for generation in generations])
-    assert all([token_id.item() == 13 for generation in generations for token_id in generation.tokens.token_ids])
-    assert all([token_text == "." for generation in generations for token_text in generation.tokens.texts])
+    assert all(
+        [
+            token_id.item() == 13
+            for generation in generations
+            for token_id in generation.tokens.token_ids
+        ]
+    )
+    assert all(
+        [
+            token_text == "."
+            for generation in generations
+            for token_text in generation.tokens.texts
+        ]
+    )
 
     assert generations[0].request_id == 0
...
@@ -151,8 +151,20 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
     )
 
     assert all([generation.generated_text is None for generation in generations])
     assert all([len(generation.prefill_tokens) == 1 for generation in generations])
-    assert all([token_id.item() == 259 for generation in generations for token_id in generation.tokens.token_ids])
-    assert all([token_text == " " for generation in generations for token_text in generation.tokens.texts])
+    assert all(
+        [
+            token_id.item() == 259
+            for generation in generations
+            for token_id in generation.tokens.token_ids
+        ]
+    )
+    assert all(
+        [
+            token_text == " "
+            for generation in generations
+            for token_text in generation.tokens.texts
+        ]
+    )
 
     assert generations[0].request_id == 0
...
@@ -77,12 +77,24 @@ def serve(
     # Downgrade enum into str for easier management later on
     quantize = None if quantize is None else quantize.value
     dtype = None if dtype is None else dtype.value
-    if dtype is not None and quantize not in {None, "bitsandbytes", "bitsandbytes-nf4", "bitsandbytes-fp4"}:
+    if dtype is not None and quantize not in {
+        None,
+        "bitsandbytes",
+        "bitsandbytes-nf4",
+        "bitsandbytes-fp4",
+    }:
         raise RuntimeError(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
         )
     server.serve(
-        model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code, uds_path
+        model_id,
+        revision,
+        sharded,
+        quantize,
+        speculate,
+        dtype,
+        trust_remote_code,
+        uds_path,
     )
@@ -140,12 +152,17 @@ def download_weights(
 
     try:
         import json
-        medusa_head = hf_hub_download(model_id, revision=revision, filename="medusa_lm_head.pt")
+
+        medusa_head = hf_hub_download(
+            model_id, revision=revision, filename="medusa_lm_head.pt"
+        )
         if auto_convert:
-            medusa_sf = Path(medusa_head[:-len(".pt")] + ".safetensors")
+            medusa_sf = Path(medusa_head[: -len(".pt")] + ".safetensors")
             if not medusa_sf.exists():
                 utils.convert_files([Path(medusa_head)], [medusa_sf], [])
-            medusa_config = hf_hub_download(model_id, revision=revision, filename="config.json")
+            medusa_config = hf_hub_download(
+                model_id, revision=revision, filename="config.json"
+            )
             with open(medusa_config, "r") as f:
                 config = json.load(f)
@@ -153,10 +170,17 @@ def download_weights(
             revision = "main"
             try:
                 utils.weight_files(model_id, revision, extension)
-                logger.info(f"Files for parent {model_id} are already present on the host. " "Skipping download.")
+                logger.info(
+                    f"Files for parent {model_id} are already present on the host. "
+                    "Skipping download."
+                )
                 return
             # Local files not found
-            except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
+            except (
+                utils.LocalEntryNotFoundError,
+                FileNotFoundError,
+                utils.EntryNotFoundError,
+            ):
                 pass
    except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
        pass
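For context on the `serve` hunk above: the reflowed condition rejects an explicit `dtype` combined with any quantizer other than the bitsandbytes variants, which, as I read it, quantize on the fly from an unquantized checkpoint and therefore can coexist with a dtype choice. A standalone sketch of that guard; `check_dtype_and_quantize` is a hypothetical helper name, not part of the CLI:

```python
from typing import Optional

# Quantizers that may be combined with an explicit dtype (mirrors the set in `serve`).
ALLOWED_WITH_DTYPE = {None, "bitsandbytes", "bitsandbytes-nf4", "bitsandbytes-fp4"}


def check_dtype_and_quantize(dtype: Optional[str], quantize: Optional[str]) -> None:
    if dtype is not None and quantize not in ALLOWED_WITH_DTYPE:
        raise RuntimeError(
            "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
        )


check_dtype_and_quantize("float16", "bitsandbytes")  # accepted
check_dtype_and_quantize(None, "gptq")               # accepted
# check_dtype_and_quantize("float16", "gptq")        # would raise RuntimeError
```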
...
@@ -88,7 +88,6 @@ if MIXTRAL:
     __all__.append(FlashMixtral)
 
 
-
 def get_model(
     model_id: str,
     revision: Optional[str],
@@ -157,7 +156,9 @@ def get_model(
         speculate_medusa = config_dict["medusa_num_heads"]
         if speculate is not None:
             if speculate > speculate_medusa:
-                raise RuntimeError("Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match")
+                raise RuntimeError(
+                    "Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
+                )
             else:
                 set_speculate(speculate)
         else:
@@ -249,7 +250,7 @@ def get_model(
                 quantize=quantize,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
-                use_medusa=use_medusa
+                use_medusa=use_medusa,
             )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
@@ -313,7 +314,9 @@ def get_model(
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
-        raise NotImplementedError("Mixtral models requires flash attention v2, stk and megablocks")
+        raise NotImplementedError(
+            "Mixtral models requires flash attention v2, stk and megablocks"
+        )
 
     if model_type == "opt":
         return OPTSharded(
@@ -354,7 +357,7 @@ def get_model(
         raise ValueError("awq quantization is not supported for AutoModel")
     elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
         raise ValueError("4bit quantization is not supported for AutoModel")
-    elif (quantize == "eetq"):
+    elif quantize == "eetq":
         raise ValueError("Eetq quantization is not supported for AutoModel")
     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
         return CausalLM(
...
@@ -74,7 +74,11 @@ class BLOOMSharded(CausalLM):
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(
-            filenames, device=device, dtype=dtype, process_group=self.process_group, prefix="transformer",
+            filenames,
+            device=device,
+            dtype=dtype,
+            process_group=self.process_group,
+            prefix="transformer",
         )
         if config.quantize == "gptq":
             weights._set_gptq_params(model_id)
...
@@ -510,7 +510,11 @@ class CausalLM(Model):
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
-        if torch.cuda.is_available() and torch.cuda.device_count() == 1 and quantize != "bitsandbytes":
+        if (
+            torch.cuda.is_available()
+            and torch.cuda.device_count() == 1
+            and quantize != "bitsandbytes"
+        ):
             model = model.cuda()
 
         if tokenizer.pad_token_id is None:
@@ -676,7 +680,10 @@ class CausalLM(Model):
                     skip_special_tokens=False,
                 )
                 prefill_tokens = Tokens(
-                    prefill_token_ids, prefill_logprobs, prefill_texts, is_special=[]
+                    prefill_token_ids,
+                    prefill_logprobs,
+                    prefill_texts,
+                    is_special=[],
                 )
             else:
                 prefill_tokens = None
@@ -703,11 +710,11 @@ class CausalLM(Model):
                 request.id,
                 prefill_tokens,
                 Tokens(
-                        [next_token_id_squeezed],
-                        [next_token_logprob],
-                        [next_token_text],
-                        [next_token_id_squeezed.item() in self.all_special_ids],
-                    ),
+                    [next_token_id_squeezed],
+                    [next_token_logprob],
+                    [next_token_text],
+                    [next_token_id_squeezed.item() in self.all_special_ids],
+                ),
                 generated_text,
                 top_tokens,
             )
...
@@ -34,9 +34,10 @@ from text_generation_server.utils.layers import (
     PositionRotaryEmbedding,
     TensorParallelHead,
     get_linear,
-    FastRMSNorm
+    FastRMSNorm,
 )
 
+
 class LlamaConfig(PretrainedConfig):
     def __init__(
         self,
@@ -202,7 +203,7 @@ class FlashLlamaAttention(torch.nn.Module):
         )
         query = query.view(-1, self.num_heads, self.head_size)
         kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
 
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
         paged_attention.reshape_and_cache(
@@ -237,7 +238,7 @@ class FlashLlamaAttention(torch.nn.Module):
             input_lengths,
             max_s,
         )
 
         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
 
@@ -288,7 +289,9 @@ class FlashLlamaLayer(nn.Module):
         )
         self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
 
-        self.input_layernorm = FastRMSNorm.load(prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps)
+        self.input_layernorm = FastRMSNorm.load(
+            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+        )
         self.post_attention_layernorm = FastRMSNorm.load(
             prefix=f"{prefix}.post_attention_layernorm",
             weights=weights,
...
@@ -27,7 +27,11 @@ from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
 from text_generation_server.utils import paged_attention, flash_attn
-from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2_ROCM, HAS_FLASH_ATTN_V2_CUDA
+from text_generation_server.utils.flash_attn import (
+    attention,
+    HAS_FLASH_ATTN_V2_ROCM,
+    HAS_FLASH_ATTN_V2_CUDA,
+)
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -35,7 +39,7 @@ from text_generation_server.utils.layers import (
     PositionRotaryEmbedding,
     TensorParallelHead,
     get_linear,
-    FastRMSNorm
+    FastRMSNorm,
 )
@@ -96,6 +100,7 @@ class MistralConfig(PretrainedConfig):
             **kwargs,
         )
 
+
 def load_attention(config, prefix, weights):
     if config.num_attention_heads != config.num_key_value_heads:
         return _load_gqa(config, prefix, weights)
...
@@ -29,7 +29,10 @@ from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
 from text_generation_server.utils import paged_attention, flash_attn
-from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_ROCM, HAS_FLASH_ATTN_V2_CUDA
+from text_generation_server.utils.flash_attn import (
+    HAS_FLASH_ATTN_V2_ROCM,
+    HAS_FLASH_ATTN_V2_CUDA,
+)
 from text_generation_server.utils.layers import (
     FastLinear,
     FastRMSNorm,
@@ -59,28 +62,28 @@ class MixtralConfig(PretrainedConfig):
     model_type = "mixtral"
 
     def __init__(
-            self,
-            vocab_size=32000,
-            hidden_size=4096,
-            intermediate_size=14336,
-            num_hidden_layers=32,
-            num_attention_heads=32,
-            num_key_value_heads=8,
-            hidden_act="silu",
-            max_position_embeddings=4096 * 32,
-            initializer_range=0.02,
-            rms_norm_eps=1e-05,
-            use_cache=True,
-            pad_token_id=None,
-            bos_token_id=1,
-            eos_token_id=2,
-            pretraining_tp=1,
-            tie_word_embeddings=False,
-            rope_theta=10000.0,
-            sliding_window=4096,
-            num_experts_per_tok=2,
-            num_local_experts=8,
-            **kwargs,
+        self,
+        vocab_size=32000,
+        hidden_size=4096,
+        intermediate_size=14336,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        num_key_value_heads=8,
+        hidden_act="silu",
+        max_position_embeddings=4096 * 32,
+        initializer_range=0.02,
+        rms_norm_eps=1e-05,
+        use_cache=True,
+        pad_token_id=None,
+        bos_token_id=1,
+        eos_token_id=2,
+        pretraining_tp=1,
+        tie_word_embeddings=False,
+        rope_theta=10000.0,
+        sliding_window=4096,
+        num_experts_per_tok=2,
+        num_local_experts=8,
+        **kwargs,
     ):
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
@@ -166,16 +169,18 @@ def _load_experts(config, prefix, mat, weights):
     rank = weights.process_group.rank()
 
     assert (
         config.intermediate_size % world_size == 0
     ), f"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards"
 
     block_size = config.intermediate_size // world_size
     start = rank * block_size
     stop = (rank + 1) * block_size
 
-    tensor = torch.empty((config.num_local_experts * block_size, config.hidden_size),
-                         dtype=weights.dtype,
-                         device=weights.device)
+    tensor = torch.empty(
+        (config.num_local_experts * block_size, config.hidden_size),
+        dtype=weights.dtype,
+        device=weights.device,
+    )
 
     for i in range(config.num_local_experts):
         slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight")
@@ -184,16 +189,18 @@ def _load_experts(config, prefix, mat, weights):
             expert_slice = slice_[:, start:stop].t().contiguous()
         else:
             expert_slice = slice_[start:stop]
-        tensor[i * block_size:(i + 1) * block_size] = expert_slice.to(dtype=weights.dtype).to(device=weights.device)
+        tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
+            dtype=weights.dtype
+        ).to(device=weights.device)
     return tensor
 
 
 class MixtralAttention(torch.nn.Module):
     def __init__(
-            self,
-            prefix: str,
-            config,
-            weights,
+        self,
+        prefix: str,
+        config,
+        weights,
     ):
         super().__init__()
         self.max_past = (
@@ -210,7 +217,7 @@ class MixtralAttention(torch.nn.Module):
             device=weights.device,
         )
 
-        self.softmax_scale = self.head_size ** -0.5
+        self.softmax_scale = self.head_size**-0.5
 
         if self.num_heads % weights.process_group.size() != 0:
             raise ValueError(
@@ -219,7 +226,7 @@ class MixtralAttention(torch.nn.Module):
             )
         self.num_heads = self.num_heads // weights.process_group.size()
         self.num_key_value_heads = (
-                config.num_key_value_heads // weights.process_group.size()
+            config.num_key_value_heads // weights.process_group.size()
         )
 
         self.query_key_value = load_attention(config, prefix, weights)
@@ -236,17 +243,17 @@ class MixtralAttention(torch.nn.Module):
         ).repeat_interleave(self.num_groups)
 
     def forward(
-            self,
-            hidden_states,
-            cos,
-            sin,
-            cu_seqlen_prefill,
-            kv_cache,
-            block_tables,
-            slots,
-            input_lengths,
-            max_s,
-            prefill_cache_indices,
+        self,
+        hidden_states,
+        cos,
+        sin,
+        cu_seqlen_prefill,
+        kv_cache,
+        block_tables,
+        slots,
+        input_lengths,
+        max_s,
+        prefill_cache_indices,
     ):
         qkv = self.query_key_value(hidden_states)
         query, kv = qkv.split(
@@ -399,8 +406,9 @@ class BlockSparseMoE(nn.Module):
         # Indices for the sparse matrix. The indices for
         # the intermediate matrix are dynamic depending
        # on the mapping of tokens to experts.
-        column_indices = ops.topology(padded_bins, self.blocking, block_rows,
-                                      blocks_per_row)
+        column_indices = ops.topology(
+            padded_bins, self.blocking, block_rows, blocks_per_row
+        )
 
         # For now, use meta init to save the device memory.
         data = torch.empty(
@@ -444,8 +452,7 @@ class BlockSparseMoE(nn.Module):
         # position of each bin.
         # List of size num_experts
-        padded_tokens_per_expert = round_up(tokens_per_expert,
-                                            self.blocking)
+        padded_tokens_per_expert = round_up(tokens_per_expert, self.blocking)
 
         # padded_tokens_per_expert => [128, O, 128, ...]
 
         # Cumulative selected experts per token
@@ -484,8 +491,7 @@ class BlockSparseMoE(nn.Module):
 
         # Permute tokens and pad to prepare expert computation
         # (top_k * sequence_length + padding, model_dim)
-        x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins,
-                              self.top_k)
+        x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.top_k)
 
         # Create the sparse matrix topology
         with torch.no_grad():
@@ -496,8 +502,8 @@ class BlockSparseMoE(nn.Module):
         # (top_k * sequence_length + padding, ffn_dim * n_experts)
         x = stk.Matrix(
             topo.size(),
-            self.act(stk.ops.sdd(x, self.w1, topo).data) *
-            stk.ops.sdd(x, self.w3, topo).data,
+            self.act(stk.ops.sdd(x, self.w1, topo).data)
+            * stk.ops.sdd(x, self.w3, topo).data,
             topo.row_indices,
             topo.column_indices,
             topo.offsets,
@@ -537,7 +543,9 @@ class MixtralLayer(nn.Module):
         self.self_attn = MixtralAttention(
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
-        self.block_sparse_moe = BlockSparseMoE(f"{prefix}.block_sparse_moe", config, weights)
+        self.block_sparse_moe = BlockSparseMoE(
+            f"{prefix}.block_sparse_moe", config, weights
+        )
 
         self.input_layernorm = FastRMSNorm.load(
             prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
@@ -549,18 +557,18 @@ class MixtralLayer(nn.Module):
         )
 
     def forward(
-            self,
-            hidden_states,
-            residual,
-            cos,
-            sin,
-            cu_seqlen_prefill,
-            kv_cache,
-            block_tables,
-            slots,
-            input_lengths,
-            max_s,
-            prefill_cache_indices,
+        self,
+        hidden_states,
+        residual,
+        cos,
+        sin,
+        cu_seqlen_prefill,
+        kv_cache,
+        block_tables,
+        slots,
+        input_lengths,
+        max_s,
+        prefill_cache_indices,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -615,16 +623,16 @@ class MixtralModel(torch.nn.Module):
         self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
 
     def forward(
-            self,
-            input_ids: torch.Tensor,
-            position_ids: torch.Tensor,
-            cu_seqlen_prefill: Optional[torch.Tensor],
-            kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-            block_tables: torch.Tensor,
-            slots: torch.Tensor,
-            input_lengths: torch.Tensor,
-            max_s: int,
-            prefill_cache_indices: Optional[torch.Tensor],
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        input_lengths: torch.Tensor,
+        max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
@@ -670,17 +678,17 @@ class FlashMixtralForCausalLM(torch.nn.Module):
             raise ValueError("max_past cannot be None")
 
     def forward(
-            self,
-            input_ids: torch.Tensor,
-            position_ids: torch.Tensor,
-            cu_seqlen_prefill: Optional[torch.Tensor],
-            kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
-            block_tables: torch.Tensor,
-            slots: torch.Tensor,
-            input_lengths: torch.Tensor,
-            max_s: int,
-            prefill_cache_indices: Optional[torch.Tensor],
-            lm_head_indices: Optional[torch.Tensor] = None,
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        input_lengths: torch.Tensor,
+        max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
+        lm_head_indices: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if prefill_cache_indices is not None:
             # Slots also need to be sliced as it has the same size as the whole kv tensor
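One clarifying example for the `_load_experts` hunks above: each shard keeps a contiguous `block_size = intermediate_size // world_size` slice of every expert and stacks those slices into a single `(num_local_experts * block_size, hidden_size)` tensor, which is what the reformatted slice assignment fills. A toy-sized sketch (tiny made-up dimensions, not the real Mixtral shapes, with a stand-in for the checkpoint slice):

```python
import torch

num_local_experts, intermediate_size, hidden_size = 2, 8, 4
world_size, rank = 2, 0  # pretend this is shard 0 of 2

block_size = intermediate_size // world_size        # rows of each expert kept on this shard
start, stop = rank * block_size, (rank + 1) * block_size

tensor = torch.empty((num_local_experts * block_size, hidden_size))
for i in range(num_local_experts):
    # Stand-in for `weights._get_slice(...)`: expert i is a constant matrix here.
    full_expert = torch.full((intermediate_size, hidden_size), float(i))
    tensor[i * block_size : (i + 1) * block_size] = full_expert[start:stop]

assert tensor.shape == (num_local_experts * block_size, hidden_size)  # (8, 4)
```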
...
@@ -198,7 +198,9 @@ class IdeficsImageProcessor(BaseImageProcessor):
            image = image_url_or_urls
 
            if image.startswith("http://") or image.startswith("https://"):
-                response = requests.get(image_url_or_urls, stream=True, headers=headers, timeout=(1, 5))
+                response = requests.get(
+                    image_url_or_urls, stream=True, headers=headers, timeout=(1, 5)
+                )
                 response.raise_for_status()
                 content = response.content
             elif image.startswith("data:"):
@@ -213,7 +215,7 @@ class IdeficsImageProcessor(BaseImageProcessor):
                 image = Image.open(BytesIO(content))
                 # image.verify()
             except Exception:
                 raise ValueError(f"Could not load image from url {image_url_or_urls}")
             return image
         else:
             raise ValueError(
...
@@ -62,6 +62,7 @@ if IS_CUDA_SYSTEM:
 elif IS_ROCM_SYSTEM:
     from vllm import layernorm_ops
 
+
 @dataclass
 class BaseModelOutputWithPastImage(BaseModelOutputWithPast):
     image_hidden_states: Optional[torch.FloatTensor] = None
@@ -431,7 +432,9 @@ class IdeficsRMSNorm(nn.Module):
             return out
         else:
-            raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
+            raise ValueError(
+                "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
+            )
 
 
 # this was adapted from LlamaMLP
@@ -613,8 +616,13 @@ class IdeficsAttention(nn.Module):
             query_shape = query_states.shape
             key_shape = key_states.shape
-            self.rotary_emb(query_states.view(-1, *query_shape[2:]), key_states.reshape(-1, *key_shape[2:]), cos, sin)
+            self.rotary_emb(
+                query_states.view(-1, *query_shape[2:]),
+                key_states.reshape(-1, *key_shape[2:]),
+                cos,
+                sin,
+            )
 
             query_states = query_states.view(query_shape)
             key_states = key_states.view(key_shape)
...
@@ -112,6 +112,7 @@ def is_url(string):
     result = urlparse(string)
     return all([result.scheme, result.netloc])
 
+
 def is_image(string):
     """Checks if the passed string contains a valid url and nothing else. e.g. if space is included it's immediately
     invalidated the url"""
@@ -344,7 +345,6 @@ class IdeficsProcessor(ProcessorMixin):
         image_objects = self.image_processor(image_objects, transform=transform)
 
-
         text_encoding = self.tokenizer(
             text=full_text,
             add_special_tokens=False,
...
@@ -11,7 +11,7 @@ from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Union, Dict
 
 from text_generation_server.models import Model
 from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.models.types import (
     Batch,
@@ -165,8 +165,6 @@ class FlashCausalLMBatch(Batch):
 
             input_length = len(tokenized_input)
             input_lengths.append(input_length)
 
-
-
             prefix_offsets.append(input_length - 5)
             read_offsets.append(input_length)
@@ -229,7 +227,9 @@ class FlashCausalLMBatch(Batch):
             cumulative_max_length += total_tokens
             max_seqlen = max(max_seqlen, input_length)
             max_blocks = max(max_blocks, needed_blocks)
-            max_length = max(max_length, input_length + max_new_tokens + speculative_length)
+            max_length = max(
+                max_length, input_length + max_new_tokens + speculative_length
+            )
 
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             next_token_chooser_parameters, dtype, device
@@ -424,7 +424,9 @@ class FlashCausalLMBatch(Batch):
         slots = self.slots[slot_filtering_indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
-        speculative_ids = self.speculative_ids[indices] if self.speculative_ids is not None else None
+        speculative_ids = (
+            self.speculative_ids[indices] if self.speculative_ids is not None else None
+        )
 
         start_slots = torch.tensor(start_slots, dtype=torch.int64)
@@ -480,7 +482,9 @@ class FlashCausalLMBatch(Batch):
             total_batch_size += len(b)
             total_slots += len(b.slots)
             blocks += b.blocks
-            speculative_length = b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
+            speculative_length = (
+                b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
+            )
             max_blocks = max(max_blocks, b.max_blocks)
             max_seqlen = max(max_seqlen, b.max_seqlen)
             max_length = max(
@@ -586,7 +590,11 @@ class FlashCausalLMBatch(Batch):
             device=batches[0].next_token_chooser.device,
         )
 
-        speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0) if batches[0].speculative_ids is not None else None
+        speculative_ids = (
+            torch.cat([b.speculative_ids for b in batches], dim=0)
+            if batches[0].speculative_ids is not None
+            else None
+        )
 
         # Needed to avoid dropping blocks when the batches will go out of scope
         for b in batches:
@@ -622,7 +630,7 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens_tensor=top_n_tokens_tensor,
             blocks=blocks,
             max_blocks=max_blocks,
-            speculative_ids=speculative_ids
+            speculative_ids=speculative_ids,
         )
 
     def __del__(self):
@@ -727,43 +735,54 @@ class FlashCausalLM(Model):
     def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
         if batch.speculative_ids is not None:
-            input_ids=batch.input_ids
-            position_ids=batch.position_ids
-            cu_seqlen_prefill=batch.cu_seqlen_prefill
-            kv_cache=get_cache_manager().kv_cache
-            block_tables=batch.block_tables_tensor
-            slots=batch.slots[batch.slot_indices]
-            input_lengths=batch.input_lengths_tensor
-            max_s=batch.max_seqlen
-            lm_head_indices=batch.prefill_head_indices
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = get_cache_manager().kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
             speculative_ids = batch.speculative_ids
 
             B, speculative_length = speculative_ids.shape
             new_length = speculative_length + 1
-            new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
+            new_input_ids = torch.cat(
+                [input_ids.unsqueeze(-1), speculative_ids], dim=1
+            ).reshape(-1)
             arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
             arange_int = arange.to(dtype=torch.int32)
-            new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
+            new_position_ids = (
+                position_ids.unsqueeze(-1).expand(B, new_length) + arange
+            ).view(-1)
             slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
-            input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+            input_lengths = (
+                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
+            ).view(-1)
 
             # Add Copy the block tables for all members
-            block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B* new_length, -1).contiguous()
+            block_tables = (
+                block_tables.unsqueeze(1)
+                .expand(B, new_length, -1)
+                .reshape(B * new_length, -1)
+                .contiguous()
+            )
             max_s = max_s + speculative_length
             input_ids = new_input_ids
             position_ids = new_position_ids
         else:
-            input_ids=batch.input_ids
-            position_ids=batch.position_ids
-            cu_seqlen_prefill=batch.cu_seqlen_prefill
-            kv_cache=get_cache_manager().kv_cache
-            block_tables=batch.block_tables_tensor
-            slots=batch.slots[batch.slot_indices]
-            input_lengths=batch.input_lengths_tensor
-            max_s=batch.max_seqlen
-            lm_head_indices=batch.prefill_head_indices
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = get_cache_manager().kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
 
         return self.model.forward(
             input_ids=input_ids,
@@ -808,20 +827,31 @@ class FlashCausalLM(Model):
             else:
                 speculative_logits = None
 
         if prefill:
             next_token_logits = (
                 out[batch.prefill_next_token_indices] if prefill_logprobs else out
             )
             if speculative_logits is not None:
                 speculative_logits = (
-                    speculative_logits[batch.prefill_next_token_indices] if prefill_logprobs else speculative_logits
+                    speculative_logits[batch.prefill_next_token_indices]
+                    if prefill_logprobs
+                    else speculative_logits
                 )
         else:
             next_token_logits = out
 
-        next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
-            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits
+        (
+            next_input_ids,
+            next_token_logprobs,
+            logprobs,
+            accepted_ids,
+            speculative_ids,
+        ) = batch.next_token_chooser(
+            batch.all_input_ids_tensor[:, : batch.max_seqlen],
+            next_token_logits,
+            get_speculate(),
+            batch.speculative_ids,
+            speculative_logits,
         )
 
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
@@ -851,11 +881,7 @@ class FlashCausalLM(Model):
             stopped = True
 
         # Zipped iterator
-        iterator = zip(
-            batch.input_lengths,
-            batch.all_input_ids,
-            accepted_ids
-        )
+        iterator = zip(batch.input_lengths, batch.all_input_ids, accepted_ids)
 
         # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
         # one, we need to first do a GPU <-> CPU sync
@@ -863,11 +889,7 @@ class FlashCausalLM(Model):
         # For each member of the batch
         index = 0
-        for i, (
-            input_length,
-            all_input_ids,
-            n_accepted_ids
-        ) in enumerate(iterator):
+        for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
             # Indexing metadata
             start_index = cumulative_length
             end_index = cumulative_length + input_length
@@ -901,7 +923,6 @@ class FlashCausalLM(Model):
             cumulative_length += input_length
 
         batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
-
         batch.speculative_ids = speculative_ids
         batch.position_ids = next_position_ids + accepted_ids
@@ -983,8 +1004,10 @@ class FlashCausalLM(Model):
                 current_stopped = False
             stopped = stopped and current_stopped
 
-            _next_token_ids = next_token_ids[index: index+n_accepted_ids - left]
-            _next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids - left]
+            _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
+            _next_token_logprobs = next_token_logprobs[
+                index : index + n_accepted_ids - left
+            ]
             index += n_accepted_ids
 
             # Shard generations
@@ -1027,7 +1050,10 @@ class FlashCausalLM(Model):
                 )
                 prefill_tokens = Tokens(
-                    prefill_token_ids, request_prefill_logprobs, prefill_texts, is_special = []
+                    prefill_token_ids,
+                    request_prefill_logprobs,
+                    prefill_texts,
+                    is_special=[],
                 )
             else:
                 prefill_tokens = None
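The largest hunk above only reflows the speculative-decoding branch of `forward`, where every sequence is expanded to `speculative_length + 1` positions by broadcasting an `arange` offset. The reshaping is easier to see on toy tensors; this sketch uses made-up values but the same expressions as the code:

```python
import torch

B, speculative_length = 2, 2
new_length = speculative_length + 1

position_ids = torch.tensor([5, 9])   # current position of each sequence (toy values)
slots = torch.tensor([100, 200])      # current kv-cache slot of each sequence (toy values)

arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
arange_int = arange.to(dtype=torch.int32)

# Each sequence gets one position/slot per speculative token plus the real one.
new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
new_slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)

print(new_position_ids.tolist())  # [5, 6, 7, 9, 10, 11]
print(new_slots.tolist())         # [100, 101, 102, 200, 201, 202]
```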
...
@@ -71,12 +71,19 @@ class FlashLlama(FlashCausalLM):
             from text_generation_server.utils.medusa import MedusaModel
             from huggingface_hub import hf_hub_download
             import json
-            medusa_config = hf_hub_download(use_medusa, revision=revision, filename="config.json")
+
+            medusa_config = hf_hub_download(
+                use_medusa, revision=revision, filename="config.json"
+            )
             with open(medusa_config, "r") as f:
                 config = json.load(f)
-            medusa_head = hf_hub_download(use_medusa, revision=revision, filename="medusa_lm_head.pt")
-            medusa_sf = medusa_head[:-len(".pt")] + ".safetensors"
-            weights = Weights([medusa_sf], device, dtype, process_group=self.process_group)
+            medusa_head = hf_hub_download(
+                use_medusa, revision=revision, filename="medusa_lm_head.pt"
+            )
+            medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
+            weights = Weights(
+                [medusa_sf], device, dtype, process_group=self.process_group
+            )
             lm_head = model.lm_head
             model.lm_head = MedusaModel(config, weights, lm_head)
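A small note on the `medusa_sf` line reformatted above: the slice `[: -len(".pt")]` simply strips the `.pt` suffix before appending `.safetensors`. A self-contained illustration with a made-up path:

```python
medusa_head = "/data/medusa/medusa_lm_head.pt"  # hypothetical download location
medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
assert medusa_sf == "/data/medusa/medusa_lm_head.safetensors"
```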
...
@@ -45,11 +45,11 @@ class FlashMistralBatch(FlashCausalLMBatch):
     @classmethod
     def from_pb(
-            cls,
-            pb: generate_pb2.Batch,
-            tokenizer: PreTrainedTokenizerBase,
-            dtype: torch.dtype,
-            device: torch.device,
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        dtype: torch.dtype,
+        device: torch.device,
     ) -> "FlashCausalLMBatch":
         global SLIDING_WINDOW
         global SLIDING_WINDOW_BLOCKS
@@ -99,12 +99,12 @@ class FlashMistralBatch(FlashCausalLMBatch):
         # Parse batch
         for i, (r, tokenized_input) in enumerate(
-                zip(pb.requests, batch_tokenized_inputs)
+            zip(pb.requests, batch_tokenized_inputs)
         ):
             # request id -> idx in list mapping
             requests_idx_mapping[r.id] = i
 
-            tokenized_input = tokenized_input[-r.truncate:]
+            tokenized_input = tokenized_input[-r.truncate :]
 
             input_length = len(tokenized_input)
             input_lengths.append(input_length)
@@ -184,7 +184,9 @@ class FlashMistralBatch(FlashCausalLMBatch):
             cumulative_max_length += total_tokens
             max_seqlen = max(max_seqlen, input_length)
             max_blocks = max(max_blocks, needed_blocks)
-            max_length = max(max_length, input_length + max_new_tokens + speculative_length)
+            max_length = max(
+                max_length, input_length + max_new_tokens + speculative_length
+            )
 
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             next_token_chooser_parameters, dtype, device
@@ -273,20 +275,20 @@ class FlashMistralBatch(FlashCausalLMBatch):
             blocks=blocks,
             max_blocks=max_blocks,
             prefill_cache_indices=prefill_cache_indices,
-            speculative_ids=None
+            speculative_ids=None,
         )
 
 
 class BaseFlashMistral(FlashCausalLM):
     def __init__(
-            self,
-            config_cls,
-            model_cls,
-            model_id: str,
-            revision: Optional[str] = None,
-            quantize: Optional[str] = None,
-            dtype: Optional[torch.dtype] = None,
-            trust_remote_code: bool = False,
+        self,
+        config_cls,
+        model_cls,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
     ):
         global SLIDING_WINDOW
         global SLIDING_WINDOW_BLOCKS
@@ -345,43 +347,54 @@ class BaseFlashMistral(FlashCausalLM):
     def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
         if batch.speculative_ids is not None:
-            input_ids=batch.input_ids
-            position_ids=batch.position_ids
-            cu_seqlen_prefill=batch.cu_seqlen_prefill
-            kv_cache=get_cache_manager().kv_cache
-            block_tables=batch.block_tables_tensor
-            slots=batch.slots[batch.slot_indices]
-            input_lengths=batch.input_lengths_tensor
-            max_s=batch.max_seqlen
-            lm_head_indices=batch.prefill_head_indices
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = get_cache_manager().kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
             speculative_ids = batch.speculative_ids
 
             B, speculative_length = speculative_ids.shape
             new_length = speculative_length + 1
-            new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
+            new_input_ids = torch.cat(
+                [input_ids.unsqueeze(-1), speculative_ids], dim=1
+            ).reshape(-1)
            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
             arange_int = arange.to(dtype=torch.int32)
-            new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
+            new_position_ids = (
+                position_ids.unsqueeze(-1).expand(B, new_length) + arange
+            ).view(-1)
             slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
-            input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+            input_lengths = (
+                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
+            ).view(-1)
 
             # Add Copy the block tables for all members
-            block_tables = block_tables.unsqueeze(1).expand(B, new_length, -1).reshape(B* new_length, -1).contiguous()
+            block_tables = (
+                block_tables.unsqueeze(1)
+                .expand(B, new_length, -1)
+                .reshape(B * new_length, -1)
+                .contiguous()
+            )
             max_s = max_s + speculative_length
             input_ids = new_input_ids
             position_ids = new_position_ids
         else:
-            input_ids=batch.input_ids
-            position_ids=batch.position_ids
-            cu_seqlen_prefill=batch.cu_seqlen_prefill
-            kv_cache=get_cache_manager().kv_cache
-            block_tables=batch.block_tables_tensor
-            slots=batch.slots[batch.slot_indices]
-            input_lengths=batch.input_lengths_tensor
-            max_s=batch.max_seqlen
-            lm_head_indices=batch.prefill_head_indices
+            input_ids = batch.input_ids
+            position_ids = batch.position_ids
+            cu_seqlen_prefill = batch.cu_seqlen_prefill
+            kv_cache = get_cache_manager().kv_cache
+            block_tables = batch.block_tables_tensor
+            slots = batch.slots[batch.slot_indices]
+            input_lengths = batch.input_lengths_tensor
+            max_s = batch.max_seqlen
+            lm_head_indices = batch.prefill_head_indices
 
         logits = self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
@@ -401,12 +414,12 @@ class BaseFlashMistral(FlashCausalLM):
 
 class FlashMistral(BaseFlashMistral):
     def __init__(
-            self,
-            model_id: str,
-            revision: Optional[str] = None,
-            quantize: Optional[str] = None,
-            dtype: Optional[torch.dtype] = None,
-            trust_remote_code: bool = False,
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
     ):
         super(FlashMistral, self).__init__(
             config_cls=MistralConfig,
@@ -415,5 +428,5 @@ class FlashMistral(BaseFlashMistral):
             revision=revision,
             quantize=quantize,
             dtype=dtype,
-            trust_remote_code=trust_remote_code
+            trust_remote_code=trust_remote_code,
         )