Unverified Commit f7f61876 authored by drbh, committed by GitHub

Pr 2290 ci run (#2329)



* MODEL_ID propagation fix

* fix: remove global model id

---------
Co-authored-by: root <root@tw031.pit.tensorwave.lan>
parent 34f7dcfd
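
The diffs below drop the module-level MODEL_ID global (and its set_model_id setter) from text_generation_server.models.globals and read the model id from the model instance instead. A minimal sketch of the before/after pattern, using a hypothetical, simplified Model class rather than the real constructor:

```python
# Before: a module-level global, populated by a setter at serve time.
MODEL_ID = None


def set_model_id(model_id: str):
    global MODEL_ID
    MODEL_ID = model_id


# After: the id is carried on the model instance, so no global state is needed.
class Model:
    def __init__(self, model_id: str):
        # Hypothetical, simplified constructor; the real model classes
        # take many more arguments.
        self.model_id = model_id

    def tunableop_filename(self, world_size: int, rank: int) -> str:
        # Builds the same filename as the diff below, but from the
        # instance attribute instead of the module-level global.
        return f"tunableop_{self.model_id.replace('/', '-')}_tp{world_size}_rank{rank}.csv"
```

Callers that previously imported MODEL_ID from the globals module now go through the model object they already hold, which is why the import and the set_model_id(model_id) call disappear in the later hunks.
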
@@ -43,7 +43,6 @@ from text_generation_server.models.globals import (
     BLOCK_SIZE,
     CUDA_GRAPHS,
     get_adapter_to_index,
-    MODEL_ID,
 )
 from text_generation_server.layers.attention import Seqlen
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
@@ -1156,7 +1155,7 @@ class FlashCausalLM(Model):
             tunableop_filepath = os.path.join(
                 HUGGINGFACE_HUB_CACHE,
-                f"tunableop_{MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
+                f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
             )
             log_master(

@@ -29,15 +29,6 @@ if cuda_graphs is not None:
     CUDA_GRAPHS = cuda_graphs
 
-# This is overridden at model loading.
-MODEL_ID = None
-
-
-def set_model_id(model_id: str):
-    global MODEL_ID
-    MODEL_ID = model_id
-
-
 # NOTE: eventually we should move this into the router and pass back the
 # index in all cases.
 ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None

@@ -30,7 +30,7 @@ except (ImportError, NotImplementedError):
 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
-from text_generation_server.models.globals import set_model_id, set_adapter_to_index
+from text_generation_server.models.globals import set_adapter_to_index
 
 
 class SignalHandler:

@@ -271,7 +271,6 @@ def serve(
         while signal_handler.KEEP_PROCESSING:
             await asyncio.sleep(0.5)
 
-    set_model_id(model_id)
     asyncio.run(
         serve_inner(
             model_id,
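
With the setter gone, serve() no longer needs to populate any global before starting the event loop: model_id is already forwarded into serve_inner and from there into model construction. A hedged sketch of the resulting flow, with simplified signatures standing in for the real entrypoint (the real serve() and serve_inner() take additional parameters elided in the hunk above):

```python
import asyncio


def serve(model_id: str) -> None:
    # Simplified stand-in for the real entrypoint; extra parameters
    # are omitted here.
    async def serve_inner(model_id: str) -> None:
        # model_id flows directly into model construction, so nothing
        # has to be stashed in a module-level global first.
        print(f"serving {model_id}")

    asyncio.run(serve_inner(model_id))


if __name__ == "__main__":
    serve("bigscience/bloom-560m")  # example model id
```
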