from transformers import AutoTokenizer, Qwen2Tokenizer

from megatron.core.datasets.megatron_tokenizer import MegatronTokenizer
from megatron.training.tokenizer.tokenizer import (
    _BertWordPieceTokenizer,
    _GPT2BPETokenizer,
    _SentencePieceTokenizer,
    _GPTSentencePieceTokenizer,
    _HuggingFaceTokenizer,
    _Llama2Tokenizer,
    CustomTikTokenizer,
    _NullTokenizer,
    _NullMultimodalTokenizer,
    _vocab_size_with_padding,
    PATTERN_TIKTOKEN,
    PATTERN_TIKTOKEN_V2,
)
from megatron.training.tokenizer.multimodal_tokenizer import MultimodalTokenizer


def build_tokenizer(args, **kwargs):
    """Initialize tokenizer."""
    if args.rank == 0:
        print('> building {} tokenizer ...'.format(args.tokenizer_type), flush=True)

    # Select and instantiate the tokenizer.
    if args.tokenizer_type == 'BertWordPieceLowerCase':
        assert args.vocab_file is not None
        tokenizer = _BertWordPieceTokenizer(
            vocab_file=args.vocab_file,
            lower_case=True,
            vocab_extra_ids=args.vocab_extra_ids,
        )
    elif args.tokenizer_type == 'BertWordPieceCase':
        assert args.vocab_file is not None
        tokenizer = _BertWordPieceTokenizer(
            vocab_file=args.vocab_file,
            lower_case=False,
            vocab_extra_ids=args.vocab_extra_ids,
        )
    elif args.tokenizer_type == 'GPT2BPETokenizer':
        assert args.vocab_file is not None
        assert args.merge_file is not None
        tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
    elif args.tokenizer_type == 'SentencePieceTokenizer':
        assert args.tokenizer_model is not None
        tokenizer = _SentencePieceTokenizer(
            args.tokenizer_model, vocab_extra_ids=args.vocab_extra_ids
        )
    elif args.tokenizer_type == 'GPTSentencePieceTokenizer':
        assert args.tokenizer_model is not None
        tokenizer = _GPTSentencePieceTokenizer(args.tokenizer_model)
    elif args.tokenizer_type == 'HuggingFaceTokenizer':
        tokenizer = _HuggingFaceTokenizer(args.tokenizer_model, **kwargs)
    elif args.tokenizer_type == 'Llama2Tokenizer':
        assert args.tokenizer_model is not None
        tokenizer = _Llama2Tokenizer(args.tokenizer_model)
    elif args.tokenizer_type == 'Llama3Tokenizer':
        assert args.tokenizer_model is not None
        tokenizer = _Llama3Tokenizer(args.tokenizer_model)
    elif args.tokenizer_type == 'QwenTokenizer':
        tokenizer = _Qwen2Tokenizer(args.vocab_file, args.merge_file)
    elif args.tokenizer_type == 'TikTokenizer':
        assert args.tokenizer_model is not None
        assert args.tiktoken_pattern is not None
        assert args.tiktoken_pattern in {"v1", "v2"}
        pattern = PATTERN_TIKTOKEN if args.tiktoken_pattern == "v1" else PATTERN_TIKTOKEN_V2
        tokenizer = CustomTikTokenizer(
            path=args.tokenizer_model,
            pattern=pattern,
            vocab_size=args.vocab_size,
            num_special_tokens=args.tiktoken_num_special_tokens,
            special_tokens=args.tiktoken_special_tokens,
        )
    elif args.tokenizer_type == 'NullTokenizer':
        assert args.vocab_size is not None
        tokenizer = _NullTokenizer(args.vocab_size)
    elif args.tokenizer_type == "MultimodalTokenizer":
        try:
            import transformers
        except ImportError:
            raise ImportError(
                "MultimodalTokenizer currently requires the transformers library to be installed"
            )

        kwargs = dict()
        if args.tokenizer_prompt_format == "nvlm-yi-34b":
            kwargs = {
                "from_slow": True,
                "legacy": False,
                "add_bos_token": True,
            }

        # Currently, only HuggingFace tokenizers are supported.
        underlying_tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=args.tokenizer_model, **kwargs
        )

        tokenizer = MultimodalTokenizer(
            underlying_tokenizer,
            args.tokenizer_prompt_format,
            args.special_tokens,
            args.image_tag_type,
            args.force_system_message,
        )
    elif args.tokenizer_type == 'NullMultimodalTokenizer':
        assert args.vocab_size is not None
        tokenizer = _NullMultimodalTokenizer(args.vocab_size)
    elif args.tokenizer_type == "DeepSeekV2Tokenizer":
        tokenizer = _DeepSeekV2Tokenizer(args.tokenizer_model, args.extra_vocab_size)
        args.padded_vocab_size = tokenizer.vocab_size
    else:
        raise NotImplementedError(
            '{} tokenizer is not implemented.'.format(args.tokenizer_type)
        )

    # Add vocab size (if not already set from a checkpoint).
    if getattr(args, "padded_vocab_size", None) is None:
        args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size, args)

    return tokenizer


class _Llama3Tokenizer(MegatronTokenizer):
    """Llama 3 tiktoken-based tokenizer, adapted for Megatron.

    See https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py
    """

    def __init__(self, model_file):
        super().__init__(model_file)
        import tiktoken
        from tiktoken.load import load_tiktoken_bpe

        tokenizer_path = model_file
        special_tokens = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|reserved_special_token_2|>",
            "<|reserved_special_token_3|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|reserved_special_token_4|>",
            "<|eot_id|>",  # end of turn
        ] + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)]
        mergeable_ranks = load_tiktoken_bpe(tokenizer_path)
        self.tokenizer = tiktoken.Encoding(
            tokenizer_path,
            pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
            mergeable_ranks=mergeable_ranks,
            special_tokens={
                token: len(mergeable_ranks) + i for i, token in enumerate(special_tokens)
            },
        )
        self.eod_id = self.tokenizer.encode("<|end_of_text|>", allowed_special="all")[0]

    @property
    def vocab_size(self):
        return self.tokenizer.n_vocab

    @property
    def vocab(self):
        # tiktoken does not expose a public token->id dict; the encode
        # callable is returned as a stand-in.
        return self.tokenizer.encode

    @property
    def inv_vocab(self):
        # Likewise for the id->token direction: return the decode callable.
        return self.tokenizer.decode

    def tokenize(self, text):
        return self.tokenizer.encode(text)

    def detokenize(self, token_ids):
        # Decode token ids back to text.
        return self.tokenizer.decode(token_ids)

    @property
    def eod(self):
        return self.eod_id


class _Qwen2Tokenizer(MegatronTokenizer):
    def __init__(self, vocab_file, merge_file, extra_vocab_size=0):
        super().__init__(vocab_file, merge_file)
        self.tokenizer = Qwen2Tokenizer(vocab_file, merge_file)
        self.extra_vocab_size = extra_vocab_size
        self.tokenizer.add_special_tokens(special_tokens_dict=dict(pad_token="<|extra_0|>"))

    @property
    def vocab_size(self):
        return len(self.tokenizer.encoder) + self.extra_vocab_size

    @property
    def vocab(self):
        return self.tokenizer.encoder

    @property
    def inv_vocab(self):
        return self.tokenizer.decoder

    def tokenize(self, text):
        return self.tokenizer.encode(text)

    def detokenize(self, token_ids):
        return self.tokenizer.decode(token_ids)

    @property
    def eod(self):
        return self.tokenizer.eos_token_id

    @property
    def eos_token(self):
        return self.tokenizer.eos_token

    @property
    def pad_token_id(self):
        return self.tokenizer.pad_token_id


class _DeepSeekV2Tokenizer(MegatronTokenizer):
    def __init__(self, tokenizer_path, extra_vocab_size):
        super().__init__(tokenizer_path)
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path, padding_side="right", trust_remote_code=True
        )
        self.extra_vocab_size = extra_vocab_size
        if self.tokenizer.chat_template is None:
            self.tokenizer.chat_template = (
                "{% if not add_generation_prompt is defined %}"
                "{% set add_generation_prompt = false %}{% endif %}"
                "{{ bos_token }}{% for message in messages %}"
                "{% if message['role'] == 'user' %}"
                "{{ 'User: ' + message['content'] + '\n\n' }}"
                "{% elif message['role'] == 'assistant' %}"
                "{{ 'Assistant: ' + message['content'] + eos_token }}"
                "{% elif message['role'] == 'system' %}"
                "{{ message['content'] + '\n\n' }}"
                "{% endif %}{% endfor %}"
                "{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
            )
        try:
            test_conversation = [{'role': 'user', 'content': 'hello world'}]
            self.apply_chat_template(test_conversation)
        except Exception:
            # The default chat_template is invalid; assume the user will not do SFT.
            self.tokenizer.chat_template = None

    def __call__(self, text, return_tensors=None, padding=None, max_length=None,
                 truncation=None, add_special_tokens=None):
        return self.tokenizer(
            text,
            return_tensors=return_tensors,
            padding=padding,
            max_length=max_length,
            truncation=truncation,
            add_special_tokens=add_special_tokens,
        )

    def apply_chat_template(self, conversations, tokenize: bool = True, **kwargs):
        return self.tokenizer.apply_chat_template(conversations, tokenize=tokenize, **kwargs)

    @property
    def vocab_size(self):
        return len(self.tokenizer) + self.extra_vocab_size - 2

    @property
    def vocab(self):
        # Standard HF API; fast tokenizers do not expose an .encoder dict.
        return self.tokenizer.get_vocab()

    @property
    def inv_vocab(self):
        # Build the inverse id->token mapping; fast tokenizers have no .decoder dict.
        return {token_id: token for token, token_id in self.tokenizer.get_vocab().items()}

    def tokenize(self, text):
        return self.tokenizer.encode(text)

    def detokenize(self, token_ids):
        return self.tokenizer.decode(token_ids)

    @property
    def eod(self):
        return self.tokenizer.eos_token_id

    @property
    def eos_token(self):
        return self.tokenizer.eos_token

    @property
    def pad_token_id(self):
        return self.tokenizer.pad_token_id

    @property
    def eos_token_id(self):
        return self.tokenizer.eos_token_id
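

if __name__ == "__main__":
    # A minimal smoke-test sketch (not part of the original module): it
    # exercises build_tokenizer with the one tokenizer type that needs no
    # files on disk. Assumes Megatron's _NullTokenizer semantics (text is a
    # string of space-separated integer ids) and the args fields read by
    # _vocab_size_with_padding (make_vocab_size_divisible_by,
    # tensor_model_parallel_size).
    from argparse import Namespace

    args = Namespace(
        rank=0,
        tokenizer_type="NullTokenizer",
        vocab_size=1000,
        make_vocab_size_divisible_by=128,
        tensor_model_parallel_size=1,
    )
    tok = build_tokenizer(args)
    print(tok.tokenize("1 2 3"))   # -> [1, 2, 3]
    print(args.padded_vocab_size)  # vocab size rounded up to a multiple of 128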