Unverified Commit f3feaf7f authored by Sayak Paul, committed by GitHub

Change variable name to prevent shadowing (#21153)

fix: input -> input_string.
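
For context, the rename avoids shadowing Python's builtin input(): once a list is bound to that name, any later call to the builtin in the same scope fails. A minimal sketch (not part of this commit) of the failure mode:

input = ["TensorFlow is"]  # shadows the builtin input()
try:
    input("Prompt: ")  # the name now refers to a list, not the builtin
except TypeError as err:
    print(err)  # 'list' object is not callable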
parent cf028d0c
@@ -83,12 +83,12 @@ check_min_version("4.21.0")
 tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left", pad_token="</s>")
 model = TFAutoModelForCausalLM.from_pretrained("gpt2")
-input = ["TensorFlow is"]
+input_string = ["TensorFlow is"]
 # One line to create an XLA generation function
 xla_generate = tf.function(model.generate, jit_compile=True)
-tokenized_input = tokenizer(input, return_tensors="tf")
+tokenized_input = tokenizer(input_string, return_tensors="tf")
 generated_tokens = xla_generate(**tokenized_input, num_beams=2)
 decoded_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
@@ -112,12 +112,12 @@ from transformers import AutoTokenizer, TFAutoModelForCausalLM
 tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left", pad_token="</s>")
 model = TFAutoModelForCausalLM.from_pretrained("gpt2")
-input = ["TensorFlow is"]
+input_string = ["TensorFlow is"]
 xla_generate = tf.function(model.generate, jit_compile=True)
 # Here, we call the tokenizer with padding options.
-tokenized_input = tokenizer(input, pad_to_multiple_of=8, padding=True, return_tensors="tf")
+tokenized_input = tokenizer(input_string, pad_to_multiple_of=8, padding=True, return_tensors="tf")
 generated_tokens = xla_generate(**tokenized_input, num_beams=2)
 decoded_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
@@ -136,8 +136,8 @@ model = TFAutoModelForCausalLM.from_pretrained("gpt2")
 xla_generate = tf.function(model.generate, jit_compile=True)
-for input in ["TensorFlow is", "TensorFlow is a", "TFLite is a"]:
-    tokenized_input = tokenizer(input, pad_to_multiple_of=8, padding=True, return_tensors="tf")
+for input_string in ["TensorFlow is", "TensorFlow is a", "TFLite is a"]:
+    tokenized_input = tokenizer(input_string, pad_to_multiple_of=8, padding=True, return_tensors="tf")
     start = time.time_ns()
     generated_tokens = xla_generate(**tokenized_input, num_beams=2)
     end = time.time_ns()
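
For reference, here is the first documentation snippet as it reads after this change, assembled into a runnable sketch (the imports come from the surrounding file; the final print is added here for illustration):

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left", pad_token="</s>")
model = TFAutoModelForCausalLM.from_pretrained("gpt2")
input_string = ["TensorFlow is"]

# One line to create an XLA generation function
xla_generate = tf.function(model.generate, jit_compile=True)

tokenized_input = tokenizer(input_string, return_tensors="tf")
generated_tokens = xla_generate(**tokenized_input, num_beams=2)

decoded_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
print(decoded_text)  # added for illustration; not in the original snippet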