Unverified Commit f7f61876 authored by drbh's avatar drbh Committed by GitHub
Browse files

Pr 2290 ci run (#2329)



* MODEL_ID propagation fix

* fix: remove global model id

---------
Co-authored-by: root <root@tw031.pit.tensorwave.lan>
parent 34f7dcfd
...@@ -43,7 +43,6 @@ from text_generation_server.models.globals import ( ...@@ -43,7 +43,6 @@ from text_generation_server.models.globals import (
BLOCK_SIZE, BLOCK_SIZE,
CUDA_GRAPHS, CUDA_GRAPHS,
get_adapter_to_index, get_adapter_to_index,
MODEL_ID,
) )
from text_generation_server.layers.attention import Seqlen from text_generation_server.layers.attention import Seqlen
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
...@@ -1156,7 +1155,7 @@ class FlashCausalLM(Model): ...@@ -1156,7 +1155,7 @@ class FlashCausalLM(Model):
tunableop_filepath = os.path.join( tunableop_filepath = os.path.join(
HUGGINGFACE_HUB_CACHE, HUGGINGFACE_HUB_CACHE,
f"tunableop_{MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv", f"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
) )
log_master( log_master(
......
...@@ -29,15 +29,6 @@ if cuda_graphs is not None: ...@@ -29,15 +29,6 @@ if cuda_graphs is not None:
CUDA_GRAPHS = cuda_graphs CUDA_GRAPHS = cuda_graphs
# This is overridden at model loading.
MODEL_ID = None
def set_model_id(model_id: str):
global MODEL_ID
MODEL_ID = model_id
# NOTE: eventually we should move this into the router and pass back the # NOTE: eventually we should move this into the router and pass back the
# index in all cases. # index in all cases.
ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None
......
...@@ -30,7 +30,7 @@ except (ImportError, NotImplementedError): ...@@ -30,7 +30,7 @@ except (ImportError, NotImplementedError):
from text_generation_server.pb import generate_pb2_grpc, generate_pb2 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.globals import set_model_id, set_adapter_to_index from text_generation_server.models.globals import set_adapter_to_index
class SignalHandler: class SignalHandler:
...@@ -271,7 +271,6 @@ def serve( ...@@ -271,7 +271,6 @@ def serve(
while signal_handler.KEEP_PROCESSING: while signal_handler.KEEP_PROCESSING:
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
set_model_id(model_id)
asyncio.run( asyncio.run(
serve_inner( serve_inner(
model_id, model_id,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment