"tests/vscode:/vscode.git/clone" did not exist on "9e4fb16f27cf1ca7057efbeff39dfd4b7c3ee66d"
Unverified Commit 341263df authored by digger yu, committed by GitHub
Browse files

[hotfix] fix typo s/get_defualt_parser /get_default_parser (#5548)

parent a799ca34
@@ -2,10 +2,10 @@ import time
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from utils import get_defualt_parser, inference, print_output
+from utils import get_default_parser, inference, print_output

 if __name__ == "__main__":
-    parser = get_defualt_parser()
+    parser = get_default_parser()
     args = parser.parse_args()
     start = time.time()
     torch.set_default_dtype(torch.bfloat16)
...
@@ -3,7 +3,7 @@ import time
 import torch
 from grok1_policy import Grok1ForCausalLMPolicy
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from utils import get_defualt_parser, inference, print_output
+from utils import get_default_parser, inference, print_output

 import colossalai
 from colossalai.booster import Booster
@@ -13,7 +13,7 @@ from colossalai.lazy import LazyInitContext
 from colossalai.utils import get_current_device

 if __name__ == "__main__":
-    parser = get_defualt_parser()
+    parser = get_default_parser()
     args = parser.parse_args()
     start = time.time()
     colossalai.launch_from_torch({})
...
@@ -33,7 +33,7 @@ def inference(model, tokenizer, text, **generate_kwargs):
     return outputs[0].tolist()


-def get_defualt_parser():
+def get_default_parser():
     parser = argparse.ArgumentParser()
     parser.add_argument("--pretrained", type=str, default="hpcaitech/grok-1")
     parser.add_argument("--tokenizer", type=str, default="tokenizer.model")
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment