Commit 466a9654 authored by VictorSanh

fix bug/typos

parent c198ff5f
@@ -86,7 +86,7 @@ def transformerXLModel(*args, **kwargs):
         # We can re-use the memory cells in a subsequent call to attend a longer context
         >>> with torch.no_grad():
                 hidden_states_1, mems_1 = model(tokens_tensor_1)
-                hidden_states_2, past = model(tokens_tensor_2, past=past)
+                hidden_states_2, mems_2 = model(tokens_tensor_2, mems=mems_1)
     """
     model = TransfoXLModel.from_pretrained(*args, **kwargs)
     return model
@@ -121,7 +121,7 @@ def transformerXLLMHeadModel(*args, **kwargs):
         # We can re-use the memory cells in a subsequent call to attend a longer context
         >>> with torch.no_grad():
                 predictions_1, mems_1 = model(tokens_tensor_1)
-                predictions_2, past = model(tokens_tensor_2, past=past)
+                predictions_2, mems_2 = model(tokens_tensor_2, mems=mems_1)
         # Get the predicted last token
         >>> predicted_index = torch.argmax(predictions_2[0, -1, :]).item()
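
For context, the corrected docstring usage reads end to end roughly as follows. This is a minimal sketch, assuming the pytorch-pretrained-BERT TransfoXL classes and the 'transfo-xl-wt103' checkpoint; the two example sentences and variable names not shown in the diff are illustrative, not part of this commit.

# Sketch of the corrected Transformer-XL memory usage (assumption: pytorch-pretrained-BERT
# with the 'transfo-xl-wt103' checkpoint; example text is illustrative only).
import torch
from pytorch_pretrained_bert import TransfoXLTokenizer, TransfoXLLMHeadModel

tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')
model.eval()

# Tokenize two consecutive text segments
text_1 = "Who was Jim Henson ?"
text_2 = "Jim Henson was a puppeteer"
tokens_tensor_1 = torch.tensor([tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text_1))])
tokens_tensor_2 = torch.tensor([tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text_2))])

with torch.no_grad():
    # First segment returns predictions and the Transformer-XL memory cells
    predictions_1, mems_1 = model(tokens_tensor_1)
    # Re-use the memory via the `mems` keyword (not `past`) so the second
    # segment can attend over the longer combined context
    predictions_2, mems_2 = model(tokens_tensor_2, mems=mems_1)

# Get the predicted last token
predicted_index = torch.argmax(predictions_2[0, -1, :]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]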