"...py_test/git@developer.sourcefind.cn:change/sglang.git" did not exist on "60dbbd086ae9bb2a2e499ef1a03a0302360ca334"
Commit f19dad61 authored by thomwolf

fixing XLM conversion tests with dummy input

parent fafd4c86
@@ -78,6 +78,7 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i
     logger.info("Loading PyTorch weights from {}".format(pt_path))
     pt_state_dict = torch.load(pt_path, map_location='cpu')
+    logger.info("PyTorch checkpoint contains {:,} parameters".format(sum(t.numel() for t in pt_state_dict.values())))
 
     return load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=tf_inputs, allow_missing_keys=allow_missing_keys)
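The new log line totals the checkpoint by summing `numel()` over every tensor in the loaded state dict. A minimal standalone sketch of that count, using a toy `nn.Linear` in place of a real checkpoint:

```python
import torch
import torch.nn as nn

# torch.load on a checkpoint yields a dict of tensors; numel() is each
# tensor's element count, so summing over the values totals the parameters.
model = nn.Linear(4, 3)                  # 4*3 weights + 3 biases = 15 parameters
state_dict = model.state_dict()
total = sum(t.numel() for t in state_dict.values())
print("{:,} parameters".format(total))   # -> 15 parameters
```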
@@ -134,7 +135,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
     start_prefix_to_remove = tf_model.base_model_prefix + '.'
 
     symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
+    tf_loaded_numel = 0
     weight_value_tuples = []
     all_pytorch_weights = set(list(pt_state_dict.keys()))
     for symbolic_weight in symbolic_weights:
@@ -159,6 +160,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
             e.args += (symbolic_weight.shape, array.shape)
             raise e
 
+        tf_loaded_numel += array.size
         # logger.warning("Initialize TF weight {}".format(symbolic_weight.name))
 
         weight_value_tuples.append((symbolic_weight, array))
@@ -169,6 +171,8 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
     if tf_inputs is not None:
         tfo = tf_model(tf_inputs, training=False)  # Make sure restore ops are run
 
+    logger.info("Loaded {:,} parameters in the TF 2.0 model.".format(tf_loaded_numel))
+
     logger.info("Weights or buffers not loaded from PyTorch model: {}".format(all_pytorch_weights))
 
     return tf_model
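Taken together, these three hunks thread a running element count (`tf_loaded_numel`) through the weight-assignment loop, so the number of parameters actually set in the TF 2.0 model can be logged and compared against the checkpoint count above. A minimal sketch of the pattern, with a hypothetical toy Keras model and a fabricated state dict standing in for the converted PyTorch weights:

```python
import numpy as np
import tensorflow as tf

tf_model = tf.keras.Sequential([tf.keras.layers.Dense(3, input_shape=(4,))])

# Fabricated stand-in for the converted PyTorch state dict, keyed like the
# model's own weights so every symbolic weight finds a matching array.
fake_state_dict = {w.name: np.zeros(w.shape.as_list(), dtype=np.float32)
                   for w in tf_model.weights}

tf_loaded_numel = 0
weight_value_tuples = []
for symbolic_weight in tf_model.trainable_weights + tf_model.non_trainable_weights:
    array = fake_state_dict[symbolic_weight.name]
    tf_loaded_numel += array.size        # running element count, as in the diff
    weight_value_tuples.append((symbolic_weight, array))

tf.keras.backend.batch_set_value(weight_value_tuples)
print("Loaded {:,} parameters in the TF 2.0 model.".format(tf_loaded_numel))  # 15
```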
@@ -460,7 +460,7 @@ class TFXLMPreTrainedModel(TFPreTrainedModel):
             langs_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
         else:
             langs_list = None
-        return [inputs_list, attns_list, langs_list]
+        return {'input_ids': inputs_list, 'attention_mask': attns_list, 'langs': langs_list}
 
 
 XLM_START_DOCSTRING = r""" The XLM model was proposed in
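The TF-side `dummy_inputs` now return a dict keyed by argument name instead of a positional list, so each tensor is routed to the right input and optional tensors like `langs` can drop out cleanly. A rough sketch of the idea with a hypothetical toy model (not the library's actual API):

```python
import tensorflow as tf

class ToyXLMLike(tf.keras.Model):
    """Hypothetical stand-in for a model whose call() unpacks a dict of inputs."""
    def __init__(self):
        super().__init__()
        self.embed = tf.keras.layers.Embedding(16, 4)

    def call(self, inputs, training=False):
        input_ids = inputs['input_ids']    # looked up by name, not by position
        langs = inputs.get('langs')        # optional: absent is fine
        h = self.embed(input_ids)
        return h if langs is None else h + self.embed(langs)

model = ToyXLMLike()
dummy = {'input_ids': tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])}
out = model(dummy, training=False)         # dummy inputs build the weights
print(out.shape)                           # (3, 5, 4)
```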
@@ -227,6 +227,16 @@ class XLMPreTrainedModel(PreTrainedModel):
     def __init__(self, *inputs, **kwargs):
         super(XLMPreTrainedModel, self).__init__(*inputs, **kwargs)
 
+    @property
+    def dummy_inputs(self):
+        inputs_list = torch.tensor([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
+        attns_list = torch.tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
+        if self.config.use_lang_emb and self.config.n_langs > 1:
+            langs_list = torch.tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
+        else:
+            langs_list = None
+        return {'input_ids': inputs_list, 'attention_mask': attns_list, 'langs': langs_list}
+
     def _init_weights(self, module):
         """ Initialize the weights. """
         if isinstance(module, nn.Embedding):
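The PyTorch `XLMPreTrainedModel` gains the matching property: the same three padded sequences and attention mask, with `langs` supplied only when the config says the model actually uses language embeddings (`use_lang_emb` and `n_langs > 1`). A standalone sketch of that gating, with hypothetical config values in place of a real `XLMConfig`:

```python
import torch

# Hypothetical config values; a real XLMConfig carries these fields.
use_lang_emb, n_langs = True, 2

dummy_inputs = {
    'input_ids': torch.tensor([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]),
    'attention_mask': torch.tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]),
    # Language ids are only meaningful for multilingual checkpoints.
    'langs': (torch.tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
              if use_lang_emb and n_langs > 1 else None),
}
print({k: (tuple(v.shape) if v is not None else None) for k, v in dummy_inputs.items()})
# {'input_ids': (3, 5), 'attention_mask': (3, 5), 'langs': (3, 5)}
```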