Commit 52b5f264 authored by luopl's avatar luopl
Browse files

将qfeat添加到test2的server服务

parent 643c0d59
...@@ -1766,13 +1766,15 @@ class ClassificationRequest(OpenAIBaseModel): ...@@ -1766,13 +1766,15 @@ class ClassificationRequest(OpenAIBaseModel):
) )
activation: Optional[bool] = None activation: Optional[bool] = None
qfeat: Optional[list] = None
# --8<-- [end:classification-extra-params] # --8<-- [end:classification-extra-params]
def to_pooling_params(self): def to_pooling_params(self):
return PoolingParams( return PoolingParams(
truncate_prompt_tokens=self.truncate_prompt_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens,
activation=self.activation) activation=self.activation,
qfeat=self.qfeat)
class ClassificationData(OpenAIBaseModel): class ClassificationData(OpenAIBaseModel):
......
...@@ -229,9 +229,9 @@ class OpenAIServing: ...@@ -229,9 +229,9 @@ class OpenAIServing:
self.model_config = model_config self.model_config = model_config
self.max_model_len = model_config.max_model_len self.max_model_len = model_config.max_model_len
self.tokenizer_mode = model_config.tokenizer_mode self.tokenizer_mode = model_config.tokenizer_mode
if model_config.tokenizer_mode == "cpm": if model_config.tokenizer_mode == "cpm":
self.tokenizer = CPM9GTokenizer(model_config.model, trust_remote_code=True) self.tokenizer = CPM9GTokenizer(model_config.model, trust_remote_code=True)
self.models = models self.models = models
...@@ -380,7 +380,8 @@ class OpenAIServing: ...@@ -380,7 +380,8 @@ class OpenAIServing:
for i, engine_prompt in enumerate(ctx.engine_prompts): for i, engine_prompt in enumerate(ctx.engine_prompts):
request_id_item = f"{ctx.request_id}-{i}" request_id_item = f"{ctx.request_id}-{i}"
if pooling_params.qfeat is not None:
engine_prompt["qfeat"] = pooling_params.qfeat
self._log_inputs( self._log_inputs(
request_id_item, request_id_item,
engine_prompt, engine_prompt,
...@@ -620,7 +621,7 @@ class OpenAIServing: ...@@ -620,7 +621,7 @@ class OpenAIServing:
if tokenizer is None: if tokenizer is None:
input_text = "" input_text = ""
else: else:
async_tokenizer = self._get_async_tokenizer(tokenizer) async_tokenizer = self._get_async_tokenizer(tokenizer)
input_text = await async_tokenizer.decode(input_ids) if self.tokenizer_mode != "cpm" else await self.tokenizer.decode_all(input_ids) input_text = await async_tokenizer.decode(input_ids) if self.tokenizer_mode != "cpm" else await self.tokenizer.decode_all(input_ids)
return self._validate_input(request, input_ids, input_text) return self._validate_input(request, input_ids, input_text)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment