Unverified Commit c2d4a3b5 authored by OlivierDehaene's avatar OlivierDehaene Committed by GitHub
Browse files

v1.4.0 (#1494)

parent d9758851
...@@ -63,11 +63,11 @@ class FlashPhi(FlashCausalLM): ...@@ -63,11 +63,11 @@ class FlashPhi(FlashCausalLM):
import json import json
import os import os
from pathlib import Path from pathlib import Path
is_local_model = (Path(use_medusa).exists() and Path(use_medusa).is_dir()) or os.getenv( is_local_model = (
"WEIGHTS_CACHE_OVERRIDE", None Path(use_medusa).exists() and Path(use_medusa).is_dir()
) is not None ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
if not is_local_model: if not is_local_model:
medusa_config = hf_hub_download( medusa_config = hf_hub_download(
use_medusa, revision=revision, filename="config.json" use_medusa, revision=revision, filename="config.json"
...@@ -78,7 +78,7 @@ class FlashPhi(FlashCausalLM): ...@@ -78,7 +78,7 @@ class FlashPhi(FlashCausalLM):
else: else:
medusa_config = str(Path(use_medusa) / "config.json") medusa_config = str(Path(use_medusa) / "config.json")
medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt") medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
with open(medusa_config, "r") as f: with open(medusa_config, "r") as f:
config = json.load(f) config = json.load(f)
medusa_sf = medusa_head[: -len(".pt")] + ".safetensors" medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
......
...@@ -5,13 +5,17 @@ from transformers import AutoConfig, AutoTokenizer ...@@ -5,13 +5,17 @@ from transformers import AutoConfig, AutoTokenizer
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
from text_generation_server.models import CausalLM from text_generation_server.models import CausalLM
from text_generation_server.models.custom_modeling.phi_modeling import PhiConfig, PhiForCausalLM from text_generation_server.models.custom_modeling.phi_modeling import (
PhiConfig,
PhiForCausalLM,
)
from text_generation_server.utils import ( from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
weight_files, weight_files,
Weights, Weights,
) )
class Phi(CausalLM): class Phi(CausalLM):
def __init__( def __init__(
self, self,
...@@ -60,4 +64,3 @@ class Phi(CausalLM): ...@@ -60,4 +64,3 @@ class Phi(CausalLM):
dtype=dtype, dtype=dtype,
device=device, device=device,
) )
...@@ -510,7 +510,9 @@ class TensorParallelEmbedding(nn.Module): ...@@ -510,7 +510,9 @@ class TensorParallelEmbedding(nn.Module):
block_size = (num_embeddings + world_size - 1) // world_size block_size = (num_embeddings + world_size - 1) // world_size
self.min_id = rank * block_size self.min_id = rank * block_size
self.max_id = min(num_embeddings, (rank + 1) * block_size) self.max_id = min(num_embeddings, (rank + 1) * block_size)
self.null_idx = weight.shape[0] # Usually block_size, might be less in non even vocab_size. self.null_idx = weight.shape[
0
] # Usually block_size, might be less in non even vocab_size.
self.process_group = weights.process_group self.process_group = weights.process_group
self.reduce = reduce self.reduce = reduce
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment