Commit 1f7cfdf6 authored by zihanl
Browse files

update ner commands

parent c02678d4
...@@ -7,4 +7,5 @@ dist/ ...@@ -7,4 +7,5 @@ dist/
tensorboard tensorboard
commands commands
*.log *.log
logs logs
\ No newline at end of file *.so
\ No newline at end of file
...@@ -18,8 +18,10 @@ punctuations = list(string.punctuation) ...@@ -18,8 +18,10 @@ punctuations = list(string.punctuation)
punctuations.append("``") punctuations.append("``")
punctuations.append("''") punctuations.append("''")
stop_words_and_punctuations = stop_words + punctuations stopwords_table = {word: True for word in stop_words}
stop_words_and_punctuations_table = {word: True for word in stop_words_and_punctuations} punctuations_table = {punc: True for punc in punctuations}
# stop_words_and_punctuations = stop_words + punctuations
# stop_words_and_punctuations_table = {word: True for word in stop_words_and_punctuations}
label_set = ["O", "B", "I"] label_set = ["O", "B", "I"]
...@@ -99,9 +101,8 @@ def generate_entity_control_data(tokenizer, ner_model, input_data): ...@@ -99,9 +101,8 @@ def generate_entity_control_data(tokenizer, ner_model, input_data):
# dialog context + entity control code (optional) + relevant control sentence (contain entity) + response # dialog context + entity control code (optional) + relevant control sentence (contain entity) + response
output_data = [] output_data = []
## TODO
n_skip, n_skip_no_overlap, n_skip_one_contain_another = 0, 0, 0 n_skip, n_skip_no_overlap, n_skip_one_contain_another = 0, 0, 0
n_control, n_entity_control, n_overlap_control = 0, 0, 0 n_control, n_entity_control, n_overlap_control, n_control_without_code = 0, 0, 0, 0
total_num_control_code = 0 total_num_control_code = 0
for sample_idx, data_item in enumerate(tqdm(input_data)): for sample_idx, data_item in enumerate(tqdm(input_data)):
# # Debug only # # Debug only
...@@ -137,7 +138,6 @@ def generate_entity_control_data(tokenizer, ner_model, input_data): ...@@ -137,7 +138,6 @@ def generate_entity_control_data(tokenizer, ner_model, input_data):
# TODO # TODO
# In general, need to trim the control sentence when it is too long. # In general, need to trim the control sentence when it is too long.
# Need to lowercase to match?
# calculate common entity between control sentence and response # calculate common entity between control sentence and response
common_entity_list = [] common_entity_list = []
...@@ -154,19 +154,30 @@ def generate_entity_control_data(tokenizer, ner_model, input_data): ...@@ -154,19 +154,30 @@ def generate_entity_control_data(tokenizer, ner_model, input_data):
# calculate overlap between control sentence and response # calculate overlap between control sentence and response
control_word_list = control_sent.split() control_word_list = control_sent.split()
response_word_list = response.split() response_word_list = response.split()
response_word_table = {wn_lemma.lemmatize(word): True for word in response_word_list} # response_word_table = {wn_lemma.lemmatize(word): True for word in response_word_list}
response_word_table = {}
for word in response_word_list:
response_word_table[wn_lemma.lemmatize(word)] = True
if "/" in word and len(word) > 0:
tokens = word.split("/")
for tok in tokens:
if len(tok) > 0:
response_word_table[wn_lemma.lemmatize(tok)] = True
overlap_phrases = [] overlap_phrases = []
temp = [] temp = []
for word in control_word_list: for word in control_word_list:
if word.lower() in stop_words_and_punctuations_table: if word in punctuations_table:
continue
if word.lower() in stopwords_table and len(temp) == 0:
continue continue
if wn_lemma.lemmatize(word) in response_word_table: if wn_lemma.lemmatize(word) in response_word_table:
temp.append(word) temp.append(word)
else: else:
if len(temp) > 0: if len(temp) > 0:
if len(temp) > 4: if len(temp) > 5:
temp = temp[:4] temp = temp[:5]
overlap_phrases.append(" ".join(temp)) overlap_phrases.append(" ".join(temp))
temp = [] temp = []
...@@ -182,7 +193,7 @@ def generate_entity_control_data(tokenizer, ner_model, input_data): ...@@ -182,7 +193,7 @@ def generate_entity_control_data(tokenizer, ner_model, input_data):
if len(control_sent_entities) > 0: if len(control_sent_entities) > 0:
n_entity_control += 1 n_entity_control += 1
# reorder control_sent_entities based on the length of the entities (in a reverse order) # reorder control_sent_entities based on the length of the entities (in a reverse order)
control_sent_entities = sorted(control_sent_entities, key=len, reverse=True) control_sent_entities = sorted(control_sent_entities, key=len, reverse=True)[:3]
for entity in control_sent_entities: for entity in control_sent_entities:
if entity not in last_turn: if entity not in last_turn:
add_flag = True add_flag = True
...@@ -228,13 +239,14 @@ def generate_entity_control_data(tokenizer, ner_model, input_data): ...@@ -228,13 +239,14 @@ def generate_entity_control_data(tokenizer, ner_model, input_data):
if len(control_code_list) > 0: if len(control_code_list) > 0:
output_data.append(splits[0] + "\t" + " [CTRL] ".join(control_code_list) + "\t" + control_sent + "\t" + response) output_data.append(splits[0] + "\t" + " [CTRL] ".join(control_code_list) + "\t" + control_sent + "\t" + response)
else: else:
n_control_without_code += 1
output_data.append(splits[0] + "\t" + control_sent + "\t" + response) output_data.append(splits[0] + "\t" + control_sent + "\t" + response)
avg_num_control_code = total_num_control_code * 1.0 / n_control avg_num_control_code = total_num_control_code * 1.0 / n_control
print("number of skip sentences: %d (one contain another: %d + no overlap: %d)" % (n_skip, n_skip_one_contain_another, n_skip_no_overlap)) print("number of skip sentences: %d (one contain another: %d + no overlap: %d)" % (n_skip, n_skip_one_contain_another, n_skip_no_overlap))
print("Total data size: %d. Number of control case: %d (entity control: %d + overlap control: %d)" % (len(output_data), n_control, n_entity_control, n_overlap_control)) print("Total data size: %d. Number of control case: %d (entity control: %d + overlap control: %d)" % (len(output_data), n_control, n_entity_control, n_overlap_control))
print("Number of control code: %d vs. number of control case: %d (averaged control code per case: %.4f)" % (total_num_control_code, n_control, avg_num_control_code)) print("Number of control code: %d; number of control case: %d; number of control case without control code: %d (averaged control code per case: %.4f)" % (total_num_control_code, n_control, n_control_without_code, avg_num_control_code))
return output_data return output_data
......
# train_ner.py command
CUDA_VISIBLE_DEVICES=0 python train_ner.py --exp_name conll2003 --exp_id 1 --model_name roberta-large --lr 3e-5 --seed 111
# gen_entityctrl_data.py command (by default is to process training data)
CUDA_VISIBLE_DEVICES=0 python gen_entityctrl_data.py
CUDA_VISIBLE_DEVICES=0 python gen_entityctrl_data.py --infer_dataname valid_random_split.txt --output_dataname valid_random_split_entity_based_control.txt
CUDA_VISIBLE_DEVICES=0 python gen_entityctrl_data.py --infer_dataname valid_topic_split.txt --output_dataname valid_topic_split_entity_based_control.txt
CUDA_VISIBLE_DEVICES=0 python gen_entityctrl_data.py --infer_dataname test_random_split_seen.txt --output_dataname test_random_split_entity_based_control.txt
CUDA_VISIBLE_DEVICES=0 python gen_entityctrl_data.py --infer_dataname test_topic_split_unseen.txt --output_dataname test_topic_split_entity_based_control.txt
CUDA_VISIBLE_DEVICES=0 python train_ner.py --exp_name conll2003 --exp_id 1 --model_name roberta-large --lr 3e-5 --seed 111
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment