Unverified Commit c8b6ae85 authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Return the permuted hidden states if return_dict=True (#18578)

parent f28f2408
......@@ -330,7 +330,8 @@ class TFConvNextMainLayer(tf.keras.layers.Layer):
hidden_states = tuple([tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]])
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
hidden_states = hidden_states if output_hidden_states else ()
return (last_hidden_state, pooled_output) + hidden_states
return TFBaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment