Commit 798cb729 authored by shenggan, committed by binmakeswell

[NFC] polish applications/Chat/coati/trainer/base.py code style (#4260)

parent b2debdc0
@@ -25,12 +25,13 @@ class SLTrainer(ABC):
         optim (Optimizer): the optimizer to use for training
     """
 
-    def __init__(self,
-                 strategy: Strategy,
-                 max_epochs: int,
-                 model: nn.Module,
-                 optimizer: Optimizer,
-                 ) -> None:
+    def __init__(
+        self,
+        strategy: Strategy,
+        max_epochs: int,
+        model: nn.Module,
+        optimizer: Optimizer,
+    ) -> None:
         super().__init__()
         self.strategy = strategy
         self.max_epochs = max_epochs
@@ -50,10 +51,7 @@ class SLTrainer(ABC):
 
     def fit(self, *args, **kwargs):
         self._before_fit(*args, **kwargs)
-        for epoch in tqdm.trange(self.max_epochs,
-                                 desc="Epochs",
-                                 disable=not is_rank_0() or self.no_epoch_bar
-                                 ):
+        for epoch in tqdm.trange(self.max_epochs, desc="Epochs", disable=not is_rank_0() or self.no_epoch_bar):
             self._train(epoch)
             self._eval(epoch)
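
For reference, `SLTrainer.fit` (reformatted above) drives three hooks: `_before_fit` receives `fit`'s forwarded arguments, then `_train` and `_eval` run once per epoch. Below is a minimal sketch of a concrete subclass, assuming the import path of the polished module; the class name and hook bodies are hypothetical, not part of this commit:

```python
from torch.utils.data import DataLoader

from coati.trainer.base import SLTrainer  # the module polished in this commit


class ToySFTTrainer(SLTrainer):
    """Hypothetical subclass illustrating the hooks SLTrainer.fit calls."""

    def _before_fit(self, train_dataloader: DataLoader) -> None:
        # fit(*args, **kwargs) forwards its arguments here before the epoch loop
        self.train_dataloader = train_dataloader
        self.no_epoch_bar = False  # read by fit()'s tqdm.trange(...) call

    def _train(self, epoch: int) -> None:
        for batch in self.train_dataloader:
            ...  # forward/backward via self.strategy, omitted in this sketch

    def _eval(self, epoch: int) -> None:
        ...  # evaluation is a no-op in this sketch
```
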
@@ -75,8 +73,7 @@ class OnPolicyTrainer(ABC):
                  buffer: NaiveReplayBuffer,
                  sample_buffer: bool,
                  dataloader_pin_memory: bool,
-                 callbacks: List[Callback] = []
-                 ) -> None:
+                 callbacks: List[Callback] = []) -> None:
         super().__init__()
         self.strategy = strategy
         self.buffer = buffer
@@ -138,7 +135,7 @@ class OnPolicyTrainer(ABC):
     @abstractmethod
     def _learn(self, update_step: int):
         """
-        Implement this method to learn from experience, either 
+        Implement this method to learn from experience, either
         sample from buffer or transform buffer into dataloader.
         """
         raise NotImplementedError()
@@ -154,13 +151,14 @@ class OnPolicyTrainer(ABC):
             self._learn(update_step)
             self._on_learn_epoch_end(update_step)
 
-    def fit(self,
-            prompt_dataloader: DataLoader,
-            pretrain_dataloader: DataLoader,
-            num_episodes: int,
-            num_collect_steps: int,
-            num_update_steps: int,
-            ):
+    def fit(
+        self,
+        prompt_dataloader: DataLoader,
+        pretrain_dataloader: DataLoader,
+        num_episodes: int,
+        num_collect_steps: int,
+        num_update_steps: int,
+    ):
         """
         The main training loop of on-policy rl trainers.
@@ -175,23 +173,16 @@ class OnPolicyTrainer(ABC):
         self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader)
 
         with self._fit_ctx():
-            for episode in tqdm.trange(num_episodes,
-                                       desc="Episodes",
-                                       disable=not is_rank_0()):
+            for episode in tqdm.trange(num_episodes, desc="Episodes", disable=not is_rank_0()):
                 with self._episode_ctx(episode):
-                    for collect_step in tqdm.trange(num_collect_steps,
-                                                    desc="Collect steps",
-                                                    disable=not is_rank_0()):
+                    for collect_step in tqdm.trange(num_collect_steps, desc="Collect steps", disable=not is_rank_0()):
                         self._collect_phase(collect_step)
                     if not self.sample_buffer:
                         # HACK(cwher): according to the design of boost API, dataloader should also be boosted,
                         #  but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted.
                         #  I only call strategy.setup_dataloader() to setup dataloader.
-                        self.dataloader = self.strategy.setup_dataloader(self.buffer,
-                                                                         self.dataloader_pin_memory)
-                    for update_step in tqdm.trange(num_update_steps,
-                                                   desc="Update steps",
-                                                   disable=not is_rank_0()):
+                        self.dataloader = self.strategy.setup_dataloader(self.buffer, self.dataloader_pin_memory)
+                    for update_step in tqdm.trange(num_update_steps, desc="Update steps", disable=not is_rank_0()):
                         self._update_phase(update_step)
                     # NOTE: this is for on-policy algorithms
                     self.buffer.clear()
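
The nesting of the reformatted `OnPolicyTrainer.fit` is visible in the hunk above: episodes wrap collect steps and update steps, and the buffer is cleared after each episode's updates. A hedged call-site sketch follows; `ppo_trainer`, `prompt_dl`, and `pretrain_dl` are placeholder names, not identifiers from this commit:

```python
# Hypothetical call site for the reformatted OnPolicyTrainer.fit signature.
ppo_trainer.fit(
    prompt_dataloader=prompt_dl,      # prompts consumed during collect steps
    pretrain_dataloader=pretrain_dl,  # wrapped in CycledDataLoader by fit
    num_episodes=10,                  # outer "Episodes" tqdm.trange loop
    num_collect_steps=1,              # _collect_phase calls per episode
    num_update_steps=1,               # _update_phase calls; buffer.clear() follows
)
```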