Commit 709e121c authored by RichardoLuo's avatar RichardoLuo Committed by binmakeswell
Browse files

[NFC] polish applications/Chat/coati/models/generation.py code style (#4275)

parent dc1b6127
...@@ -5,7 +5,6 @@ import torch.distributed as dist ...@@ -5,7 +5,6 @@ import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
try: try:
from transformers.generation_logits_process import ( from transformers.generation_logits_process import (
LogitsProcessorList, LogitsProcessorList,
...@@ -148,12 +147,12 @@ def generate(model: nn.Module, ...@@ -148,12 +147,12 @@ def generate(model: nn.Module,
@torch.no_grad() @torch.no_grad()
def generate_with_actor(actor_model: nn.Module, def generate_with_actor(
input_ids: torch.Tensor, actor_model: nn.Module,
return_action_mask: bool = True, input_ids: torch.Tensor,
**kwargs return_action_mask: bool = True,
) -> Union[Tuple[torch.LongTensor, torch.LongTensor], **kwargs
Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]: ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
"""Generate token sequence with actor model. Refer to `generate` for more details. """Generate token sequence with actor model. Refer to `generate` for more details.
""" """
# generate sequences # generate sequences
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment