Commit 01a0914d authored by zihanl's avatar zihanl
Browse files

add comments

parent 3ec549ba
...@@ -94,8 +94,7 @@ def generate_samples_by_prompting_input_from_file(model): ...@@ -94,8 +94,7 @@ def generate_samples_by_prompting_input_from_file(model):
input_str = all_raw_text[input_pos] input_str = all_raw_text[input_pos]
input_str = input_str.strip() input_str = input_str.strip()
splits = input_str.split("\t") splits = input_str.split("\t")
control_codes = splits[0].split(" [CTRL] ") topic = splits[0]
topic = control_codes[0]
# first add the prompt into the inputs # first add the prompt into the inputs
if args.dynamic_prompt: if args.dynamic_prompt:
...@@ -137,6 +136,7 @@ def generate_samples_by_prompting_input_from_file(model): ...@@ -137,6 +136,7 @@ def generate_samples_by_prompting_input_from_file(model):
if input_pos % 100 == 0: if input_pos % 100 == 0:
print_rank_0("input_pos: %d" % input_pos) print_rank_0("input_pos: %d" % input_pos)
# get the generation outputs (in decode_tokens)
token_stream = get_token_stream(model, [context_tokens]) token_stream = get_token_stream(model, [context_tokens])
for _, decode_tokens in enumerate(token_stream): for _, decode_tokens in enumerate(token_stream):
pass pass
...@@ -169,7 +169,6 @@ def main(): ...@@ -169,7 +169,6 @@ def main():
# Set up model and load checkpoint. # Set up model and load checkpoint.
model = get_model(model_provider) model = get_model(model_provider)
if args.load is not None: if args.load is not None:
_ = load_checkpoint(model, None, None) _ = load_checkpoint(model, None, None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment