Unverified Commit 079bf3cb authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
This diff is collapsed.
......@@ -2,7 +2,4 @@ from .base import Strategy
from .colossalai import GeminiStrategy, LowLevelZeroStrategy
from .ddp import DDPStrategy
__all__ = [
'Strategy', 'DDPStrategy',
'LowLevelZeroStrategy', 'GeminiStrategy'
]
__all__ = ["Strategy", "DDPStrategy", "LowLevelZeroStrategy", "GeminiStrategy"]
......@@ -19,7 +19,7 @@ _BoostArgSpec = Union[nn.Module, Tuple[nn.Module, Optimizer], Dict]
class Strategy(ABC):
"""
Base class for training strategies.
Base class for training strategies.
"""
def __init__(self, plugin_initializer: Callable[..., Optional[Plugin]] = lambda: None) -> None:
......@@ -83,16 +83,18 @@ class Strategy(ABC):
rets.append((model, optimizer))
elif isinstance(arg, Dict):
model, optimizer, criterion, dataloader, lr_scheduler = self.booster.boost(**arg)
boost_result = dict(model=model,
optimizer=optimizer,
criterion=criterion,
dataloader=dataloader,
lr_scheduler=lr_scheduler)
boost_result = dict(
model=model,
optimizer=optimizer,
criterion=criterion,
dataloader=dataloader,
lr_scheduler=lr_scheduler,
)
# remove None values
boost_result = {key: value for key, value in boost_result.items() if value is not None}
rets.append(boost_result)
else:
raise RuntimeError(f'Type {type(arg)} is not supported')
raise RuntimeError(f"Type {type(arg)} is not supported")
return rets[0] if len(rets) == 1 else rets
......@@ -125,11 +127,9 @@ class Strategy(ABC):
return DistributedSampler(dataset, 1, 0)
@abstractmethod
def save_pretrained(self,
model: nn.Module,
path: str,
only_rank0: bool = True,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
def save_pretrained(
self, model: nn.Module, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None
) -> None:
pass
@abstractmethod
......
......@@ -42,7 +42,6 @@ def is_rank_0() -> bool:
def to_device(x: Any, device: torch.device) -> Any:
def _to(t: Any):
if isinstance(t, torch.Tensor):
return t.to(device)
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -7,6 +7,9 @@ from .utils import (
)
__all__ = [
'get_evaluator', 'convert_data_to_unieval_format', 'calculate_average_score', 'save_unieval_results',
'analyze_unieval_results'
"get_evaluator",
"convert_data_to_unieval_format",
"calculate_average_score",
"save_unieval_results",
"analyze_unieval_results",
]
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment