Unverified Commit c85be1f6 authored by Nicole LiHui 🥜's avatar Nicole LiHui 🥜 Committed by GitHub
Browse files

optimize: eliminate duplicate split_enc_dec_inputs calls (#25573)


Signed-off-by: default avatarnicole-lihui <nicole.li@daocloud.io>
parent 845adb3e
...@@ -388,9 +388,9 @@ class Processor: ...@@ -388,9 +388,9 @@ class Processor:
eos_token_id = self.input_preprocessor.get_eos_token_id() eos_token_id = self.input_preprocessor.get_eos_token_id()
self._validate_model_inputs(processed_inputs)
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
self._validate_model_inputs(encoder_inputs, decoder_inputs)
# Mypy does not always properly infer the types of some elements of # Mypy does not always properly infer the types of some elements of
# discriminated unions of TypedDicts, because of how it handles # discriminated unions of TypedDicts, because of how it handles
# inheritance of TypedDict. If we explicitly extract the items we want # inheritance of TypedDict. If we explicitly extract the items we want
...@@ -458,9 +458,8 @@ class Processor: ...@@ -458,9 +458,8 @@ class Processor:
trace_headers=trace_headers, trace_headers=trace_headers,
) )
def _validate_model_inputs(self, inputs: ProcessorInputs): def _validate_model_inputs(self, encoder_inputs: Optional[SingletonInputs],
encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs) decoder_inputs: SingletonInputs):
if encoder_inputs is not None: if encoder_inputs is not None:
self._validate_model_input(encoder_inputs, prompt_type="encoder") self._validate_model_input(encoder_inputs, prompt_type="encoder")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment