"...composable_kernel-1.git" did not exist on "605afd0fb6158aba58dd8146501461b0a50487a9"
Unverified commit ab6cdb2b, authored by Fengzhe Zhou and committed by GitHub
Browse files

[Sync] Bump version 0.2.3 (#957)

parent 64fde73b
__version__ = '0.2.2' __version__ = '0.2.3'
...@@ -175,6 +175,7 @@ class ChatInferencer(BaseInferencer): ...@@ -175,6 +175,7 @@ class ChatInferencer(BaseInferencer):
temperature: Optional[float] = 0.0, temperature: Optional[float] = 0.0,
do_sample: Optional[bool] = False, do_sample: Optional[bool] = False,
infer_mode: str = 'last', infer_mode: str = 'last',
max_out_len: int = 512,
**kwargs) -> None: **kwargs) -> None:
super().__init__( super().__init__(
model=model, model=model,
...@@ -193,6 +194,7 @@ class ChatInferencer(BaseInferencer): ...@@ -193,6 +194,7 @@ class ChatInferencer(BaseInferencer):
save_every = 1 save_every = 1
self.save_every = save_every self.save_every = save_every
self.dialogue_mode = False self.dialogue_mode = False
self.max_out_len = max_out_len
def _set_meta_template(self, model): def _set_meta_template(self, model):
origin = model.template_parser origin = model.template_parser
...@@ -334,8 +336,8 @@ class ChatInferencer(BaseInferencer): ...@@ -334,8 +336,8 @@ class ChatInferencer(BaseInferencer):
] ]
history = chat[:assistant_indices[-1]] history = chat[:assistant_indices[-1]]
output = self.model.generate_from_template([history], output = self.model.generate_from_template(
max_out_len=512)[0] [history], max_out_len=self.max_out_len)[0]
output_handler.save_results( output_handler.save_results(
origin_prompt=history, origin_prompt=history,
prediction=output, prediction=output,
...@@ -356,11 +358,11 @@ class ChatInferencer(BaseInferencer): ...@@ -356,11 +358,11 @@ class ChatInferencer(BaseInferencer):
[history], [history],
do_sample=self.do_sample, do_sample=self.do_sample,
temperature=self.temperature, temperature=self.temperature,
max_out_len=512)[0] max_out_len=self.max_out_len)[0]
else: else:
output = self.model.generate_from_template([history], output = self.model.generate_from_template(
do_sample=False, [history], do_sample=False,
max_out_len=512)[0] max_out_len=self.max_out_len)[0]
chat[i]['content'] = output chat[i]['content'] = output
if not self.dialogue_mode: if not self.dialogue_mode:
output_handler.save_multiround_results( output_handler.save_multiround_results(
...@@ -397,8 +399,8 @@ class ChatInferencer(BaseInferencer): ...@@ -397,8 +399,8 @@ class ChatInferencer(BaseInferencer):
for i in assistant_indices: for i in assistant_indices:
history = chat[:i] history = chat[:i]
output = self.model.generate_from_template([history], output = self.model.generate_from_template(
max_out_len=512)[0] [history], max_out_len=self.max_out_len)[0]
output_handler.save_multiround_results( output_handler.save_multiround_results(
origin_prompt=history[-1]['content'], origin_prompt=history[-1]['content'],
prediction=output, prediction=output,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment