Commit ce29d4d5 authored by Mohammad

working on refactoring text generation

parent a0bcee94
......@@ -39,117 +39,87 @@ from megatron.model import GPT2Model
from megatron.model import DistributedDataParallel as DDP
from megatron import print_rank_0
def get_model(args):
def model_provider():
"""Build the model."""
args = get_args()
print_rank_0('building GPT2 model ...')
model = GPT2Model(num_layers=args.num_layers,
vocab_size=args.vocab_size,
hidden_size=args.hidden_size,
num_attention_heads=args.num_attention_heads,
embedding_dropout_prob=args.hidden_dropout,
attention_dropout_prob=args.attention_dropout,
output_dropout_prob=args.hidden_dropout,
max_sequence_length=args.max_position_embeddings,
checkpoint_activations=args.checkpoint_activations,
checkpoint_num_layers=args.checkpoint_num_layers,
parallel_output=False)
if mpu.get_data_parallel_rank() == 0:
print(' > number of parameters on model parallel rank {}: {}'.format(
mpu.get_model_parallel_rank(),
sum([p.nelement() for p in model.parameters()])), flush=True)
# GPU allocation.
model.cuda(torch.cuda.current_device())
# Fp16 conversion.
if args.fp16:
model = FP16_Module(model)
# Wrap model for distributed training.
model = DDP(model)
model = GPT2Model(num_tokentypes=0, parallel_output=False)
return model
def setup_model(args):
"""Setup model and optimizer."""
model = get_model(args)
if args.load is not None:
_ = load_checkpoint(
model, None, None, args)
return model
def get_batch(context_tokens, args):
tokens = context_tokens
tokens = tokens.view(args.batch_size, -1).contiguous()
device = args.device
tokens = tokens.to(device)
def get_batch(context_tokens):
"""Generate batch from context tokens."""
args = get_args()
tokenizer = get_tokenizer()
# Get the masks and position ids.
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
# Move to GPU.
tokens = context_tokens.view(args.batch_size, -1).contiguous().cuda()
# Get the attention mask and position ids.
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
tokens,
args.eod_token,
tokenizer.eod,
args.reset_position_ids,
args.reset_attention_mask,
False)
# Fp16 conversion.
if args.fp16:
attention_mask = attention_mask.half()
args.eod_mask_loss,
args.fp16)
return tokens, attention_mask, position_ids
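For reference, a minimal sketch of what the left-to-right mask and position ids look like in the simple case (no position/attention-mask resets, no EOD loss masking). This is an illustration only, not Megatron's get_ltor_masks_and_position_ids helper; the tensor shapes and the 1-means-attend convention are assumptions made for the toy example.

import torch

def ltor_masks_and_position_ids_sketch(tokens):
    """Toy stand-in for the simple case of get_ltor_masks_and_position_ids."""
    _, seq_length = tokens.size()
    # Lower-triangular causal mask: position i may attend to positions <= i.
    # (1 = attend, 0 = masked in this sketch; the real helper also handles
    # resets at EOD tokens and returns a loss mask.)
    attention_mask = torch.tril(
        torch.ones((1, seq_length, seq_length), dtype=torch.uint8))
    # Position ids are simply 0 .. seq_length-1, repeated for each sequence.
    position_ids = torch.arange(seq_length, dtype=torch.long)
    position_ids = position_ids.unsqueeze(0).expand_as(tokens)
    return attention_mask, position_ids

toks = torch.randint(0, 100, (2, 5))
mask, pos = ltor_masks_and_position_ids_sketch(toks)
print(mask[0])  # 5x5 lower-triangular matrix of ones
print(pos[0])   # tensor([0, 1, 2, 3, 4])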
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
# This function has been mostly taken from huggingface conversational ai code at
# https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
""" This function has been mostly taken from huggingface conversational
ai code at
https://medium.com/huggingface/how-to-build-a-state-of-the-art-
conversational-ai-with-transfer-learning-2d818ac26313 """
if top_k > 0:
# Remove all tokens with a probability less than the last token of the top-k
# Remove all tokens with a probability less than the
# last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p > 0.0:
#convert to 1D
# logits=logits.view(logits.size()[1]).contiguous()
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Convert to 1D
sorted_logits, sorted_indices = torch.sort(
logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
# Shift the indices to the right to keep also the first token
# above the threshold
sorted_indices_to_remove[..., 1:] \
= sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
for i in range(sorted_indices.size(0)):
indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
logits[i][indices_to_remove] = filter_value
#going back to 2D
# logits=logits.view(1, -1).contiguous()
return logits
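A small usage sketch, assuming top_k_logits as defined above: how top-k and nucleus (top-p) filtering prune a toy batch of logits before sampling. The logit values and thresholds are made up for illustration.

import torch
import torch.nn.functional as F

# Toy 1 x 5 batch of logits; token 0 scores highest, token 4 lowest.
logits = torch.tensor([[4.0, 3.0, 2.0, 1.0, 0.0]])

# Top-k: keep only the two highest-scoring tokens, push the rest to -inf.
top_k_filtered = top_k_logits(logits.clone(), top_k=2)
print(F.softmax(top_k_filtered, dim=-1))  # mass split between tokens 0 and 1

# Nucleus (top-p): keep the smallest prefix of the sorted distribution whose
# cumulative probability exceeds 0.7 -- here that is tokens 0 and 1 as well.
top_p_filtered = top_k_logits(logits.clone(), top_p=0.7)
print(torch.isfinite(top_p_filtered))     # only the kept tokens remain finite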
def generate_samples_input_from_file(model, tokenizer, args):
if args.sample_input_file == "":
if mpu.get_model_parallel_rank() == 0:
print("args.sample_input_file CAN NOT BE empty!\n")
return
def generate_samples_input_from_file(model):
"""XXX"""
args = get_args()
tokenizer = get_tokenizer()
# Read the sample file and open the output file.
assert args.sample_input_file is not None, \
'sample input file is not provided.'
if mpu.get_model_parallel_rank() == 0:
fname = open(args.sample_input_file, "r")
all_raw_text = fname.readlines()
input_count = len(all_raw_text)
input_pos = 0
if args.sample_output_file == "":
print("Argument: sample-output-file can't be empty, setting it to\n")
print("\t args.sample_input_file.out")
args.sample_output_file = args.sample_input_file+".out"
fname_out = open(args.sample_output_file, "w+")
if args.sample_output_file is None:
sample_output_file = args.sample_input_file + ".out"
print('could not find `sample-output-file`, setting '
'it to {}'.format(sample_output_file))
fname_out = open(sample_output_file, "w+")
context_count = 0
model.eval()
......@@ -167,46 +137,44 @@ def generate_samples_input_from_file(model, tokenizer, args):
if "stop" in raw_text:
terminate_runs = 1
else:
context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization
context_tokens = tokenizer.tokenize(raw_text)
context_length = len(context_tokens)
if context_length >=args.seq_length//2:
if context_length >= (args.seq_length // 2):
print("\nContext length", context_length, \
"\nPlease give smaller context (half of the sequence length)!")
"\nPlease give smaller context (half of the "
"sequence length)!", flush=True)
continue
else:
context_tokens = tokenizer.EncodeAsIds("EMPTY TEXT").tokenization
context_tokens = tokenizer.tokenize("EMPTY TEXT")
context_length = len(context_tokens)
terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group())
torch.distributed.broadcast(terminate_runs_tensor,
mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
terminate_runs = terminate_runs_tensor[0].item()
if terminate_runs == 1:
return
start_time = time.time()
token_stream = get_token_stream(model, [context_tokens], tokenizer, args)
token_stream = get_token_stream(model, [context_tokens])
for counter, decode_tokens in enumerate(token_stream):
# token_end = decode_tokens.find("<|endoftext|>")
# if token_end > 0:
# break
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
if mpu.get_model_parallel_rank() == 0:
os.system('clear')
#print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
print("\nContext:", raw_text, flush=True)
trim_decode_tokens = tokenizer.DecodeIds(decode_tokens)[len(raw_text):]
#print("\nMegatron-LM:", trim_decode_tokens.replace("\n", "\n\n"), flush=True)
trim_decode_tokens = tokenizer.detokenize(
decode_tokens)[len(raw_text):]
print("\nMegatron-LM:", trim_decode_tokens, flush=True)
fname_out.write("\nContext:")
fname_out.write(raw_text)
fname_out.write("\n\nMegatron-LM:")
fname_out.write(trim_decode_tokens)
#fname_out.write(trim_decode_tokens.replace("\n", "\n\n"))
fname_out.write("\n")
raw_text = None
......@@ -214,9 +182,11 @@ def generate_samples_input_from_file(model, tokenizer, args):
torch.distributed.barrier(group=mpu.get_model_parallel_group())
context_count += 1
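The terminate-runs handshake inside the loop above follows a common pattern: only the model-parallel source rank reads user input, and its decision is broadcast so every rank leaves the loop together. Below is a minimal single-process sketch of that broadcast; the gloo backend, localhost address, and CPU tensor are assumptions for illustration, whereas the real code uses CUDA tensors and the model-parallel group.

import torch
import torch.distributed as dist

# Single-process "group" purely for illustration; rank 0 plays the
# model-parallel source rank that read the prompt.
dist.init_process_group(backend='gloo',
                        init_method='tcp://127.0.0.1:29500',
                        rank=0, world_size=1)

user_typed_stop = False                                  # stand-in for `"stop" in raw_text`
flag = torch.LongTensor([1 if user_typed_stop else 0])   # real code uses torch.cuda.LongTensor
dist.broadcast(flag, src=0)                              # src = mpu.get_model_parallel_src_rank()
if flag[0].item() == 1:
    print('all ranks stop generating')
else:
    print('all ranks continue')

dist.destroy_process_group()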
def generate_samples_interactive(model, tokenizer, args):
print_frequency = 24
def generate_samples_interactive(model, print_frequency=24):
"""XXX"""
args = get_args()
tokenizer = get_tokenizer()
context_count = 0
model.eval()
......@@ -235,79 +205,81 @@ def generate_samples_interactive(model, tokenizer, args):
if "stop" in raw_text:
terminate_runs = 1
else:
context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization
context_tokens = tokenizer.tokenize(raw_text)
context_length = len(context_tokens)
if context_length >=args.seq_length//2:
if context_length >= (args.seq_length // 2):
print("\nContext length", context_length, \
"\nPlease give smaller context (half of the sequence length)!")
"\nPlease give smaller context (half of the "
"sequence length)!", flush=True)
continue
else:
context_tokens = tokenizer.EncodeAsIds("EMPTY TEXT").tokenization
context_tokens = tokenizer.tokenize("EMPTY TEXT")
context_length = len(context_tokens)
terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group())
torch.distributed.broadcast(terminate_runs_tensor,
mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
terminate_runs = terminate_runs_tensor[0].item()
if terminate_runs == 1:
return
start_time = time.time()
token_stream = get_token_stream(model, [context_tokens], tokenizer, args)
token_stream = get_token_stream(model, [context_tokens])
for counter, decode_tokens in enumerate(token_stream):
# token_end = decode_tokens.find("<|endoftext|>")
# if token_end > 0:
# break
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
if mpu.get_model_parallel_rank() == 0 and counter % print_frequency == 0:
if mpu.get_model_parallel_rank() == 0 and \
counter % print_frequency == 0:
os.system('clear')
#print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
print("\nContext:", raw_text, flush=True)
trim_decode_tokens = tokenizer.DecodeIds(decode_tokens)[len(raw_text):]
#print("\nGPT2:", trim_decode_tokens, flush=True)
#print("\nMegatron-LM:", trim_decode_tokens.replace("\n", "\n\n"), flush=True)
trim_decode_tokens = tokenizer.detokenize(
decode_tokens)[len(raw_text):]
print("\nMegatron-LM:", trim_decode_tokens, flush=True)
if mpu.get_model_parallel_rank() == 0:
os.system('clear')
#print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
print("\nContext:", raw_text, flush=True)
trim_decode_tokens = tokenizer.DecodeIds(decode_tokens)[len(raw_text):]
#print("\nGPT2:", trim_decode_tokens, flush=True)
#print("\nMegatron-LM:", trim_decode_tokens.replace("\n", "\n\n"), flush=True)
trim_decode_tokens = tokenizer.detokenize(
decode_tokens)[len(raw_text):]
print("\nMegatron-LM:", trim_decode_tokens, flush=True)
raw_text = None
torch.distributed.barrier(group=mpu.get_model_parallel_group())
context_count += 1
if mpu.get_model_parallel_rank() == 0:
input("\nPress any key to continue >>>")
def generate_samples_unconditional(model, tokenizer, args):
def generate_samples_unconditional(model):
"""XXX"""
args = get_args()
tokenizer = get_tokenizer()
num_samples = args.num_samples
context_tokens = [[tokenizer.get_command('pad').Id] for _ in range(args.batch_size)]
context_tokens = [[tokenizer.eod]
for _ in range(args.batch_size)]
samples = []
# with open(args.genfile, 'w') as f:
ctr = 0
while True:
start_time = time.time()
for token_stream in get_token_stream(model, copy.deepcopy(context_tokens), tokenizer, args):
for token_stream in get_token_stream(model,
copy.deepcopy(context_tokens)):
pass
# token_stream = list(get_token_stream(model, copy.deepcopy(context_tokens), tokenizer, args))
if ctr % args.log_interval == 0:
print('Avg s/batch:', (time.time()-start_time)/min(args.log_interval, ctr+1))
print('Avg s/batch:',
(time.time() - start_time) / min(args.log_interval, ctr + 1))
start_time = time.time()
length = len(token_stream)
token_batch = token_stream[0].cpu().numpy().tolist()
length_batch = token_stream[1].cpu().numpy().tolist()
for tokens, length in zip(token_batch, length_batch):
tokens = tokens[1:length-1]
text = tokenizer.DecodeIds(tokens)
text = tokenizer.detokenize(tokens)
is_finished = length < args.seq_length - 1
datum = {'text': text, 'length': length-1, 'finished': is_finished}
yield datum
......@@ -317,35 +289,42 @@ def generate_samples_unconditional(model, tokenizer, args):
if ctr >= num_samples:
break
def write_and_generate_samples_unconditional(model, tokenizer, args):
def write_and_generate_samples_unconditional(model):
args = get_args()
assert args.genfile is not None
with open(args.genfile, 'w') as f:
for datum in generate_samples_unconditional(model, tokenizer, args):
for datum in generate_samples_unconditional(model):
f.write(json.dumps(datum)+'\n')
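The generated file is one JSON object per line (JSONL). A hedged sketch of reading it back; the file name below is a stand-in for whatever --genfile points at.

import json

with open('unconditional_samples.jsonl', 'r') as f:   # path is an assumption
    for line in f:
        datum = json.loads(line)
        # Each record carries the detokenized text, its length, and whether
        # generation stopped before hitting the sequence-length limit.
        print(datum['finished'], datum['length'], datum['text'][:80])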
def pad_batch(batch, tokenizer, args):
pad_id = tokenizer.get_command('pad').Id
pad_id = tokenizer.eod
context_lengths = []
for tokens in batch:
context_length = len(tokens)
if context_length < args.seq_length:
tokens.extend([pad_id]*(args.seq_length-context_length))
tokens.extend([pad_id]*(args.seq_length - context_length))
context_lengths.append(context_length)
return batch, context_lengths
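A toy illustration of the padding contract above: every context is right-padded with the EOD/pad id up to the sequence length, and the original lengths are returned alongside. The pad id and sequence length below are made-up stand-ins for tokenizer.eod and args.seq_length.

pad_id = 0          # stand-in for tokenizer.eod
seq_length = 8      # stand-in for args.seq_length

batch = [[11, 12, 13], [21, 22, 23, 24, 25]]
context_lengths = []
for tokens in batch:
    context_lengths.append(len(tokens))
    tokens.extend([pad_id] * (seq_length - len(tokens)))

print(batch)            # [[11, 12, 13, 0, 0, 0, 0, 0], [21, 22, 23, 24, 25, 0, 0, 0]]
print(context_lengths)  # [3, 5]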
def get_token_stream(model, context_tokens, tokenizer, args):
pad_id = tokenizer.get_command('pad').Id
# context_length = len(context_tokens)
# if context_length < args.seq_length:
# context_tokens = context_tokens + [pad_id] * (args.seq_length - context_length)
def get_token_stream(model, context_tokens):
args = get_args()
tokenizer = get_tokenizer()
pad_id = tokenizer.eod
context_tokens, context_lengths = pad_batch(context_tokens, tokenizer, args)
context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
context_length_tensor = torch.cuda.LongTensor(context_lengths)
# context_length_tensor = torch.cuda.LongTensor([context_length])
torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group())
torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group())
torch.distributed.broadcast(context_length_tensor,
mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
torch.distributed.broadcast(context_tokens_tensor,
mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
context_length = context_length_tensor.min().item()
tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, args)
......@@ -355,7 +334,9 @@ def get_token_stream(model, context_tokens, tokenizer, args):
layer_past = None
batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor, context_length_tensor, attention_mask, position_ids, tokenizer, args)
batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
context_length_tensor,
attention_mask, position_ids)
for tokens, lengths in batch_token_iterator:
context_length += 1
yield tokens[:, :context_length], lengths
......@@ -365,14 +346,14 @@ def switch(val1, val2, boolean):
boolean = boolean.type_as(val1)
return (1-boolean)*val1 + boolean*val2
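switch is an element-wise select used during decoding: positions still inside their prompt keep the original context token, while positions that have already started generating take the newly sampled token. A toy example (the token values are made up):

import torch

def switch(val1, val2, boolean):
    # Same element-wise select as above: take val2 where boolean is True,
    # keep val1 elsewhere.
    boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2

# Two sequences at the same decoding step: the first prompt has already
# ended (started=True) so it takes the sampled token, the second is still
# inside its context so it keeps the original token.
context_column = torch.tensor([0, 42])      # tokens[:, context_length]
sampled = torch.tensor([17, 99])            # prev, from torch.multinomial
started = torch.tensor([True, False])       # context_lengths <= context_length

print(switch(context_column, sampled, started))   # tensor([17, 42])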
def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask, position_ids, tokenizer, args, maxlen=None, type_ids=None):
actual_model = model
if isinstance(actual_model, DDP):
actual_model = actual_model.module
if isinstance(actual_model, FP16_Module):
actual_model = actual_model.module
original_output_parallel = actual_model.parallel_output
actual_model.parallel_output = False
def sample_sequence_batch(model, context_tokens, context_lengths,
attention_mask, position_ids,
maxlen=None, type_ids=None):
"""XXX"""
args = get_args()
tokenizer = get_tokenizer()
model.eval()
with torch.no_grad():
context_length = context_lengths.min().item()
......@@ -395,7 +376,11 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask
while context_length <= (maxlen):
if args.recompute:
logits = model(tokens, position_ids, attention_mask, tokentype_ids=type_ids)
logits = model(tokens,
position_ids,
attention_mask,
tokentype_ids=type_ids,
forward_method_parallel_output=False)
logits = logits[:, context_length - 1, :]
else:
types2use = None
......@@ -405,11 +390,20 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask
if type_ids is not None:
types2use = type_ids[:, :context_length]
else:
tokens2use = tokens[:, context_length - 1].view(batch_size, -1)
positions2use = position_ids[:, context_length - 1].view(batch_size, -1)
tokens2use = tokens[:, context_length - 1].view(
batch_size, -1)
positions2use = position_ids[:, context_length - 1].view(
batch_size, -1)
if type_ids is not None:
types2use = type_ids[:, context_length - 1].view(batch_size, -1)
logits, layer_past = model(tokens2use, positions2use, attention_mask, layer_past=layer_past, get_key_value=True, tokentype_ids=types2use)
types2use = type_ids[:, context_length - 1].view(
batch_size, -1)
logits, layer_past = model(tokens2use,
positions2use,
attention_mask,
layer_past=layer_past,
get_key_value=True,
tokentype_ids=types2use,
forward_method_parallel_output=False)
logits = logits[:, -1].view(batch_size, -1).contiguous()
if args.greedy:
......@@ -417,15 +411,18 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask
else:
logits = logits.float()
logits /= args.temperature
logits = top_k_logits(logits, top_k=args.top_k, top_p=args.top_p)
logits = top_k_logits(logits, top_k=args.top_k,
top_p=args.top_p)
log_probs = F.softmax(logits, dim=-1)
prev = torch.multinomial(log_probs, num_samples=1).view(-1)
print_logits = []
for p in prev:
print_logits.append([logits[i, p].item() for i in range(batch_size)])
print_logits.append([logits[i, p].item()
for i in range(batch_size)])
started = context_lengths <= context_length
tokens[:, context_length] = switch(tokens[:, context_length].view(-1), prev, started)
tokens[:, context_length] = switch(
tokens[:, context_length].view(-1), prev, started)
context_length += 1
counter += 1
......@@ -439,75 +436,54 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask
yield tokens, lengths
if done:
break
actual_model.parallel_output = original_output_parallel
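For orientation, here is one decode step of the sampling branch in isolation: temperature scaling, top-k filtering, then multinomial sampling. This is a toy sketch with made-up logits and a standalone copy of the filtering step, not the Megatron forward pass; the real loop above additionally reuses cached keys/values via layer_past unless --recompute is given.

import torch
import torch.nn.functional as F

torch.manual_seed(0)

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])   # stand-in model output, batch of 1
temperature, top_k = 0.8, 2

logits = logits / temperature
# Keep only the top-k candidates (same effect as top_k_logits with top_k=2).
kth_value = torch.topk(logits, top_k)[0][..., -1, None]
logits = logits.masked_fill(logits < kth_value, -float('Inf'))

probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).view(-1)
print(probs)        # only two non-zero entries
print(next_token)   # index of the sampled token (0 or 1)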
def prepare_tokenizer(args):
def add_text_generate_args(parser):
"""Text generate arguments."""
group = parser.add_argument_group('Text generation', 'configurations')
group.add_argument("--temperature", type=float, default=1.0)
group.add_argument("--greedy", action='store_true', default=False)
group.add_argument("--top_p", type=float, default=0.0)
group.add_argument("--top_k", type=int, default=0)
group.add_argument("--out-seq-length", type=int, default=1024)
group.add_argument("--sample-input-file", type=str, default=None,
help='get input from file instead of interactive mode, '
'each line is an input' )
group.add_argument("--sample-output-file", type=str, default=None,
help='output file got from --sample-input-file')
group.add_argument("--num-samples", type=int, default=0,
help='number of samples to generate unconditionally, '
'defaults to 0 and interactive conditional sampling')
group.add_argument("--genfile", type=str,
help='output file when generating unconditionally')
group.add_argument("--recompute", action='store_true',
help='during generation recompute all attention '
'instead of using previously computed keys/values.')
return parser
tokenizer_args = {
'tokenizer_type': args.tokenizer_type,
'corpus': None,
'model_path': args.tokenizer_path,
'vocab_size': args.vocab_size,
'model_type': args.tokenizer_model_type,
'cache_dir': args.cache_dir}
tokenizer = make_tokenizer(**tokenizer_args)
args.tokenizer_num_tokens = tokenizer.num_tokens
args.tokenizer_num_type_tokens = tokenizer.num_type_tokens
args.eod_token = tokenizer.get_command('eos').Id
after = tokenizer.num_tokens
multiple = args.make_vocab_size_divisible_by * \
mpu.get_model_parallel_world_size()
if multiple != 0:
while (after % multiple) != 0:
after += 1
args.vocab_size = after
print("prepare tokenizer done", flush=True)
return tokenizer
def main():
"""Main training program."""
"""Main program."""
print('Generate Samples')
# Disable CuDNN.
torch.backends.cudnn.enabled = False
# Timer.
timers = Timers()
# Arguments.
args = get_args()
# Pytorch distributed.
initialize_distributed(args)
# Random seeds for reproducability.
set_random_seed(args.seed)
#get the tokenizer
tokenizer = prepare_tokenizer(args)
# Model, optimizer, and learning rate.
model = setup_model(args)
#setting default batch size to 1
# args.batch_size = 1
args.device = torch.cuda.current_device()
initialize_megatron(extra_args_provider=add_text_generate_args,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
# Set up model and load checkpoint.
model = get_model(model_provider)
if args.load is not None:
_ = load_checkpoint(model, None, None)
# Generate samples.
if args.num_samples == 0:
args.batch_size = 1
if args.sample_input_file != "":
generate_samples_input_from_file(model, tokenizer, args)
generate_samples_input_from_file(model)
else:
generate_samples_interactive(model, tokenizer, args)
generate_samples_interactive(model)
else:
write_and_generate_samples_unconditional(model, tokenizer, args)
write_and_generate_samples_unconditional(model)
if __name__ == "__main__":
......
......@@ -357,29 +357,7 @@ def _add_gpt2_args(parser):
def add_text_generate_args(parser):
"""Text generate arguments."""
group = parser.add_argument_group('Text generation', 'configurations')
group.add_argument("--temperature", type=float, default=1.0)
group.add_argument("--greedy", action='store_true', default=False)
group.add_argument("--top_p", type=float, default=0.0)
group.add_argument("--top_k", type=int, default=0)
group.add_argument("--out-seq-length", type=int, default=1024)
group.add_argument("--sample-input-file", type=str, default="",
help='get input from file instead of interactive mode, '
'each line is an input' )
group.add_argument("--sample-output-file", type=str, default="",
help='output file got from --sample-input-file')
group.add_argument("--num-samples", type=int, default=0,
help='number of samples to generate unconditionally, '
'defaults to 0 and interactive conditional sampling')
group.add_argument("--genfile", type=str,
help='output file when generating unconditionally')
group.add_argument("--recompute", action='store_true',
help='during generation recompute all attention '
'instead of using previously computed keys/values.')
return parser
def _add_data_args(parser):
......
......@@ -137,8 +137,7 @@ class BertModel(MegatronModule):
self._binary_head_key = 'binary_head'
def forward(self, input_ids, attention_mask,
tokentype_ids=None):
def forward(self, input_ids, attention_mask, tokentype_ids=None):
extended_attention_mask = bert_extended_attention_mask(
attention_mask, next(self.language_model.parameters()).dtype)
......
......@@ -51,7 +51,8 @@ class GPT2Model(MegatronModule):
def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False):
tokentype_ids=None, layer_past=None, get_key_value=False,
forward_method_parallel_output=None):
# Language model.
lm_output = self.language_model(input_ids,
......@@ -65,10 +66,13 @@ class GPT2Model(MegatronModule):
lm_output, presents = lm_output
# Output.
parallel_output = self.parallel_output
if forward_method_parallel_output is not None:
parallel_output = forward_method_parallel_output
output = parallel_lm_logits(
lm_output,
self.language_model.embedding.word_embeddings.weight,
self.parallel_output)
parallel_output)
if get_key_value:
output = [output, presents]
......
......@@ -91,6 +91,10 @@ class AbstractTokenizer(ABC):
def tokenize(self, text):
pass
def detokenize(self, token_ids):
raise NotImplementedError('detokenizer is not implemented for {} '
'tokenizer'.format(self.name))
@property
def cls(self):
raise NotImplementedError('CLS is not provided for {} '
......@@ -190,6 +194,9 @@ class _GPT2BPETokenizer(AbstractTokenizer):
def tokenize(self, text):
return self.tokenizer.encode(text)
def detokenize(self, token_ids):
return self.tokenizer.decode(token_ids)
@property
def eod(self):
return self.eod_id
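With detokenize added, tokenize/detokenize form a round trip through the GPT-2 BPE encoder and decoder. A hedged sketch, assuming initialize_megatron has already run with tokenizer_type='GPT2BPETokenizer' (as main() does above) so the global tokenizer exists:

from megatron import get_tokenizer

tokenizer = get_tokenizer()

ids = tokenizer.tokenize("Megatron-LM generates text.")
text = tokenizer.detokenize(ids)   # round trip back through the BPE decoder
print(ids[:5], '...')
print(text)                        # "Megatron-LM generates text."
print(tokenizer.eod)               # id of the <|endoftext|> token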