"git@developer.sourcefind.cn:OpenDAS/ktransformers.git" did not exist on "be4b27e841d2084a4db05d0c7a9b27a696f83073"
Unverified commit b6dc35fe authored by WRH, committed by GitHub

[Improve] Add docstrings to pytorch submodule (#93)

* add some docstrings.

* update docstring.

fix

* ignore magic methods
parent 69b6eabe
@@ -39,7 +39,7 @@ jobs:
       - name: Check docstring coverage
         run: |
           python -m pip install interrogate
-          interrogate -v --ignore-init-method --ignore-module --ignore-private --ignore-nested-functions --ignore-nested-classes --fail-under 80 lmdeploy
+          interrogate -v --ignore-init-method --ignore-magic --ignore-module --ignore-private --ignore-nested-functions --ignore-nested-classes --fail-under 80 lmdeploy
       - name: Check pylint score
         run: |
           python -m pip install pylint
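For context: `--ignore-magic` tells interrogate to skip dunder methods when computing docstring coverage, so the check no longer penalizes classes for undocumented magic methods. A minimal illustration with a hypothetical module:

    class Resource:
        """Documented class: still counted toward coverage."""

        def __enter__(self):
            # An undocumented dunder like this used to lower the score;
            # with --ignore-magic it is excluded from the tally.
            return self

        def __exit__(self, *exc):
            return False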
@@ -15,6 +15,8 @@ class LoadNoInit:
         self.kaiming_normal_ = torch.nn.init.kaiming_normal_

     def __enter__(self, *args, **kwargs):
+        """Replace initializers with no-op."""
         torch.nn.init.constant_ = lambda *args, **kwargs: None
         torch.nn.init.zeros_ = lambda *args, **kwargs: None
         torch.nn.init.ones_ = lambda *args, **kwargs: None
@@ -24,6 +26,8 @@ class LoadNoInit:
         torch.nn.init.kaiming_normal_ = lambda *args, **kwargs: None

     def __exit__(self, *args, **kwargs):
+        """Recover the original initializers."""
         torch.nn.init.constant_ = self.constant_
         torch.nn.init.zeros_ = self.zeros_
         torch.nn.init.ones_ = self.ones_
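A hedged usage sketch of `LoadNoInit` (the model class and path are placeholders; the point is that the patched initializers are no-ops inside the block, so module construction skips redundant random initialization before real weights are loaded):

    from transformers import AutoModelForCausalLM

    with LoadNoInit():
        # torch.nn.init.constant_/zeros_/ones_/... do nothing in here
        model = AutoModelForCausalLM.from_pretrained('path/to/model')
    # On exit, the original torch.nn.init functions are restored.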
@@ -26,6 +26,8 @@ except ImportError:

 def input_prompt():
+    """Helper function for getting input from users."""
     print('\ndouble enter to end input >>> ', end='')
     sentinel = ''  # ends when this string is seen
     return '\n'.join(iter(input, sentinel))
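The two-argument form of `iter` is what makes "double enter" terminate input: `iter(input, '')` keeps calling `input()` until it returns the empty-string sentinel, i.e. until the user presses Enter on a blank line. A standalone equivalent:

    lines = iter(input, '')    # yields lines until input() returns ''
    text = '\n'.join(lines)    # join everything typed before the blank line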
@@ -38,10 +40,19 @@ def init_model(
     local_rank=0,
     world_size=1,
 ):
-    """Note:
-    If the model is converted from new version of transformers,
-    use_fast_tokenizer should be True.
-    If using depodaca/llama-xb-hf, use_fast_tokenizer should be False.
-    """
+    """Initialize model and tokenizer from given path.
+
+    Args:
+        model_path (str): Path to model.
+        tokenizer_path (str): Path to tokenizer.
+        use_fast_tokenizer (bool): Whether to use fast tokenizer.
+        local_rank (int): Local rank of current process.
+        world_size (int): World size of current process.
+
+    Note:
+        If the model is converted from new version of transformers,
+        use_fast_tokenizer should be True.
+        If using depodaca/llama-xb-hf, use_fast_tokenizer should be False.
+    """
     if not _is_transformers_available:
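A call matching the documented signature might look like this (the paths are placeholders, and returning a model/tokenizer pair is an assumption inferred from the summary line, not shown in this diff):

    model, tokenizer = init_model(
        'path/to/model',
        tokenizer_path='path/to/tokenizer',
        use_fast_tokenizer=True,  # True for checkpoints converted by recent transformers
        local_rank=0,
        world_size=1,
    )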
@@ -86,6 +97,18 @@ def main(
     seed: int = 0,
     use_fast_tokenizer: bool = True,
 ):
+    """Start chat session with given model.
+
+    Args:
+        model_path (str): Path to model.
+        tokenizer_path (str): Path to tokenizer.
+        max_new_tokens (int): Maximum number of tokens to generate.
+        temperature (float): Temperature for sampling.
+        top_p (float): Top p for sampling.
+        seed (int): Random seed.
+        use_fast_tokenizer (bool): Whether to use fast tokenizer.
+    """
     torch.manual_seed(seed)
     local_rank = int(os.getenv('LOCAL_RANK', '0'))
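Given the documented arguments, an equivalent programmatic invocation of the chat entry point could be the following (the sampling values are illustrative, not taken from the repository):

    main(
        'path/to/model',
        tokenizer_path='path/to/tokenizer',
        max_new_tokens=128,
        temperature=0.8,
        top_p=0.95,
        seed=0,
        use_fast_tokenizer=True,
    )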
@@ -8,6 +8,8 @@ from transformers.generation.streamers import BaseStreamer

 def get_utils(model):
+    """Get utils by model type."""
     name = model.__class__.__name__
     if name == 'InferenceEngine':
         name = model.module.__class__.__name__
@@ -21,7 +23,7 @@ def get_utils(model):

 class DecodeOutputStreamer(BaseStreamer):
-    """Output generated tokens to shell."""
+    """Default streamer for HuggingFace models."""

     def __init__(self, tokenizer, skip_prompt=True) -> None:
         super().__init__()
@@ -35,6 +37,8 @@ class DecodeOutputStreamer(BaseStreamer):
             self.decode = self._decode_fallback

     def _decode_with_raw_id(self, value):
+        """Convert token ids to tokens and decode."""
         tok = self.tokenizer._convert_id_to_token(value)
         if tok.startswith('▁'):  # sentencepiece
             space = ' '
@@ -48,12 +52,16 @@ class DecodeOutputStreamer(BaseStreamer):
             return space + tok

     def _decode_fallback(self, value):
+        """Fallback decoder for non-fast tokenizer."""
         tok = self.tokenizer.decode(value,
                                     skip_special_tokens=False,
                                     clean_up_tokenization_spaces=False)
         return tok + ' '

     def put(self, value):
+        """Callback function to decode token and output to stdout."""
         if self.gen_len == 0 and self.skip_prompt:
             pass
         else:
@@ -63,11 +71,13 @@ class DecodeOutputStreamer(BaseStreamer):
         self.gen_len += 1

     def end(self):
+        """Callback function to finish generation."""
         print('\n')


 class InternLMStreamer(DecodeOutputStreamer):
-    """Output generated tokens to shell."""
+    """Streamer for InternLM."""

     def __init__(self, tokenizer, skip_prompt=True) -> None:
         BaseStreamer().__init__()
@@ -77,6 +87,8 @@ class InternLMStreamer(DecodeOutputStreamer):
         self.hex_regex = re.compile(r'^<0x([0-9ABCDEF]+)>$')

     def decode(self, value):
+        """Decode generated tokens for InternLM."""
         tok = self.tokenizer.decode(value)
         if res := self.hex_regex.match(tok):
             tok = chr(int(res.group(1), 16))
@@ -87,29 +99,37 @@ class InternLMStreamer(DecodeOutputStreamer):

 class BaseDecorator:
+    """Base decorator for decorating prompt and extracting generated output."""

     @classmethod
     def decorate(cls, prompt):
+        """Abstract method for adding special tokens to prompt."""
         return prompt

     @classmethod
     def extract(cls, gen_out):
+        """Abstract method for extracting generated output from model output."""
         return gen_out


 class InternLMDecorator(BaseDecorator):
+    """Decorator for InternLM."""

     regex = re.compile(r'<\|Bot\|>:(.*)')

     @classmethod
     def decorate(cls, prompt):
+        """Decorate prompt for InternLM."""
         return f'<|User|>:{prompt}<eoh>'

     @classmethod
     def extract(cls, gen_out):
+        """Extract generated tokens for InternLM."""
         return cls.regex.search(gen_out).group(1)


 class InternLMStoppingCriteria(StoppingCriteria):
+    """Stopping criteria for HF version of InternLM."""

     def __call__(self, input_ids, *args, **kwargs) -> bool:
         return input_ids[0, -1] in [2, 103028]
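Grounded in the literals visible in this diff, the decorator pair forms a round trip around generation (the model reply below is invented for illustration):

    prompt = InternLMDecorator.decorate('Hello')
    # prompt == '<|User|>:Hello<eoh>'

    gen_out = '<|Bot|>:Hi, how can I help?'  # hypothetical decoded output
    reply = InternLMDecorator.extract(gen_out)
    # reply == 'Hi, how can I help?'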