Unverified commit d83c633c authored by Edenzzzz, committed by GitHub

[hotfix] Fix examples no pad token & auto parallel codegen bug; (#5606)



* fix no pad token bug

* fixed some auto parallel codegen bugs, but they might not run on torch 2.1

---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
parent a0ad587c
@@ -246,7 +246,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
 @compatibility(is_backward_compatible=True)
 class ActivationCheckpointCodeGen(CodeGen):
-    def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
+    def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace, verbose=None) -> PythonCode:
         free_vars: List[str] = []
         body: List[str] = []
         globals_: Dict[str, Any] = {}
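This signature change (repeated in the two codegen hunks further down) keeps the override in step with newer torch.fx releases, whose CodeGen._gen_python_code passes an extra verbose argument; without it the call fails with a TypeError about an unexpected keyword. A minimal, hedged sketch of a version-tolerant override — the class name here is illustrative and not part of the patch:

from torch.fx.graph import CodeGen, PythonCode, _Namespace


class PassthroughCodeGen(CodeGen):
    # Accept the `verbose` keyword that newer torch.fx versions pass when
    # generating code, while remaining callable on older versions that omit it.
    def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace, verbose=None) -> PythonCode:
        # Delegate to the stock implementation; a real subclass (such as the
        # activation-checkpoint codegen above) would emit custom code here.
        return super()._gen_python_code(nodes, root_module, namespace)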
@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn
 from colossalai.utils import _cast_float
-from colossalai.zero.legacy.gemini.tensor_utils import free_storage
+from colossalai.utils.common import free_storage
 from .region_manager import RegionManager
 from .util import GlobalRuntimeInfo
@@ -3,7 +3,8 @@ from typing import Dict, List, Tuple
 import torch
 from torch.fx import Node
-from colossalai.zero.legacy.gemini.tensor_utils import alloc_storage, free_storage
+from colossalai.utils.common import free_storage
+from colossalai.zero.gemini.chunk.chunk import alloc_storage

 class Region:
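Both import hunks retarget helpers that moved when the colossalai.zero.legacy package was dropped: free_storage now comes from colossalai.utils.common and alloc_storage from colossalai.zero.gemini.chunk.chunk. Conceptually they release and re-create a tensor's underlying storage in place; the sketch below shows the general idea only and is not the library's exact implementation:

import torch


def free_storage(tensor: torch.Tensor) -> None:
    # Shrink the underlying storage to zero elements so its memory can be
    # reclaimed, while the tensor object (shape/dtype metadata) stays alive.
    if tensor.storage().size() > 0:
        tensor.storage().resize_(0)


def alloc_storage(tensor: torch.Tensor) -> None:
    # Grow the storage back so the tensor can hold real data again,
    # e.g. before copying offloaded parameters back onto the device.
    if tensor.storage().size() == 0:
        tensor.storage().resize_(tensor.numel())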
@@ -372,7 +372,7 @@ if AUTOCHUNK_AVAILABLE:
             if print_progress:
                 get_logger().info("AutoChunk start codegen")

-        def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
+        def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace, verbose=None) -> PythonCode:
             free_vars: List[str] = []
             body: List[str] = []
             globals_: Dict[str, Any] = {}
@@ -625,7 +625,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
 if CODEGEN_AVAILABLE:

     class ActivationCheckpointCodeGen(CodeGen):
-        def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
+        def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace, verbose=None) -> PythonCode:
             free_vars: List[str] = []
             body: List[str] = []
             globals_: Dict[str, Any] = {}
@@ -62,6 +62,8 @@ class GLUEDataBuilder:
         self.text_fields = self.task_text_field_map[task_name]
         self.num_labels = self.glue_task_num_labels[task_name]
         self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
+        if not getattr(self.tokenizer, "pad_token", None):
+            self.tokenizer.pad_token = self.tokenizer._eos_token
         self.setup()

     def setup(self):
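GPT-style tokenizers ship without a pad token, so batched padding in the GLUE example fails with a "no padding token" error; the two added lines fall back to the end-of-sequence token. A standalone, hedged illustration of the same pattern using the public eos_token attribute (the model name is chosen only for the example):

from transformers import AutoTokenizer

# "gpt2" is just an example of a tokenizer that has no pad token by default.
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)

if tokenizer.pad_token is None:
    # Reuse the end-of-sequence token for padding so that
    # tokenizer(..., padding=True) works on batched inputs.
    tokenizer.pad_token = tokenizer.eos_token

batch = tokenizer(
    ["a short sentence", "another, slightly longer example sentence"],
    padding=True,
    return_tensors="pt",
)
print(batch["input_ids"].shape)  # both rows padded to the longer length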