"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "32f3955c6d7e98ae506826890f4ff1493de4cd64"
Commit 4775ec35 authored by thomwolf's avatar thomwolf
Browse files

add overwrite - fix ner decoding

parent f79a7dc6
...@@ -32,7 +32,8 @@ def run_command_factory(args): ...@@ -32,7 +32,8 @@ def run_command_factory(args):
reader = PipelineDataFormat.from_str(format=format, reader = PipelineDataFormat.from_str(format=format,
output_path=args.output, output_path=args.output,
input_path=args.input, input_path=args.input,
column=args.column if args.column else nlp.default_input_names) column=args.column if args.column else nlp.default_input_names,
overwrite=args.overwrite)
return RunCommand(nlp, reader) return RunCommand(nlp, reader)
...@@ -54,6 +55,7 @@ class RunCommand(BaseTransformersCLICommand): ...@@ -54,6 +55,7 @@ class RunCommand(BaseTransformersCLICommand):
run_parser.add_argument('--column', type=str, help='Name of the column to use as input. (For multi columns input as QA use column1,columns2)') run_parser.add_argument('--column', type=str, help='Name of the column to use as input. (For multi columns input as QA use column1,columns2)')
run_parser.add_argument('--format', type=str, default='infer', choices=PipelineDataFormat.SUPPORTED_FORMATS, help='Input format to read from') run_parser.add_argument('--format', type=str, default='infer', choices=PipelineDataFormat.SUPPORTED_FORMATS, help='Input format to read from')
run_parser.add_argument('--device', type=int, default=-1, help='Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)') run_parser.add_argument('--device', type=int, default=-1, help='Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)')
run_parser.add_argument('--overwrite', action='store_true', help='Allow overwriting the output file.')
run_parser.set_defaults(func=run_command_factory) run_parser.set_defaults(func=run_command_factory)
def run(self): def run(self):
...@@ -61,6 +63,7 @@ class RunCommand(BaseTransformersCLICommand): ...@@ -61,6 +63,7 @@ class RunCommand(BaseTransformersCLICommand):
for entry in self._reader: for entry in self._reader:
output = nlp(**entry) if self._reader.is_multi_columns else nlp(entry) output = nlp(**entry) if self._reader.is_multi_columns else nlp(entry)
print(output)
if isinstance(output, dict): if isinstance(output, dict):
outputs.append(output) outputs.append(output)
else: else:
...@@ -68,10 +71,10 @@ class RunCommand(BaseTransformersCLICommand): ...@@ -68,10 +71,10 @@ class RunCommand(BaseTransformersCLICommand):
# Saving data # Saving data
if self._nlp.binary_output: if self._nlp.binary_output:
binary_path = self._reader.save_binary(output) binary_path = self._reader.save_binary(outputs)
logger.warning('Current pipeline requires output to be in binary format, saving at {}'.format(binary_path)) logger.warning('Current pipeline requires output to be in binary format, saving at {}'.format(binary_path))
else: else:
self._reader.save(output) self._reader.save(outputs)
...@@ -107,7 +107,7 @@ class PipelineDataFormat: ...@@ -107,7 +107,7 @@ class PipelineDataFormat:
""" """
SUPPORTED_FORMATS = ['json', 'csv', 'pipe'] SUPPORTED_FORMATS = ['json', 'csv', 'pipe']
def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str]): def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
self.output_path = output_path self.output_path = output_path
self.input_path = input_path self.input_path = input_path
self.column = column.split(',') if column is not None else [''] self.column = column.split(',') if column is not None else ['']
...@@ -116,7 +116,7 @@ class PipelineDataFormat: ...@@ -116,7 +116,7 @@ class PipelineDataFormat:
if self.is_multi_columns: if self.is_multi_columns:
self.column = [tuple(c.split('=')) if '=' in c else (c, c) for c in self.column] self.column = [tuple(c.split('=')) if '=' in c else (c, c) for c in self.column]
if output_path is not None: if output_path is not None and not overwrite:
if exists(abspath(self.output_path)): if exists(abspath(self.output_path)):
raise OSError('{} already exists on disk'.format(self.output_path)) raise OSError('{} already exists on disk'.format(self.output_path))
...@@ -152,25 +152,26 @@ class PipelineDataFormat: ...@@ -152,25 +152,26 @@ class PipelineDataFormat:
return binary_path return binary_path
@staticmethod
def from_str(format: str, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
    """Factory: build the PipelineDataFormat subclass matching *format*.

    Args:
        format: One of 'json', 'csv' or 'pipe'.
        output_path: Where results will be saved (may be None).
        input_path: Where input entries are read from (may be None).
        column: Column specification forwarded to the reader (may be None).
        overwrite: If True, allow clobbering an existing output file.

    Raises:
        KeyError: If *format* names no known reader.
    """
    readers = {
        'json': JsonPipelineDataFormat,
        'csv': CsvPipelineDataFormat,
        'pipe': PipedPipelineDataFormat,
    }
    if format not in readers:
        raise KeyError('Unknown reader {} (Available reader are json/csv/pipe)'.format(format))
    return readers[format](output_path, input_path, column, overwrite=overwrite)
class CsvPipelineDataFormat(PipelineDataFormat): class CsvPipelineDataFormat(PipelineDataFormat):
def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
    """CSV-backed reader/writer; all configuration is delegated to the base class.

    *overwrite* permits replacing an existing file at *output_path*.
    """
    super().__init__(output_path, input_path, column, overwrite=overwrite)
def __iter__(self): def __iter__(self):
with open(self.input_path, 'r') as f: with open(self.input_path, 'r') as f:
reader = csv.DictReader(f) reader = csv.DictReader(f)
for row in reader: for row in reader:
print(row, self.column)
if self.is_multi_columns: if self.is_multi_columns:
yield {k: row[c] for k, c in self.column} yield {k: row[c] for k, c in self.column}
else: else:
...@@ -185,8 +186,8 @@ class CsvPipelineDataFormat(PipelineDataFormat): ...@@ -185,8 +186,8 @@ class CsvPipelineDataFormat(PipelineDataFormat):
class JsonPipelineDataFormat(PipelineDataFormat): class JsonPipelineDataFormat(PipelineDataFormat):
def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
    """JSON-backed reader: eagerly loads all entries from *input_path*.

    *overwrite* permits replacing an existing file at *output_path*.
    """
    super().__init__(output_path, input_path, column, overwrite=overwrite)
    # Entries are fully materialized here; iteration happens over self._entries.
    with open(input_path, 'r') as fp:
        self._entries = json.load(fp)
...@@ -460,6 +461,8 @@ class NerPipeline(Pipeline): ...@@ -460,6 +461,8 @@ class NerPipeline(Pipeline):
Named Entity Recognition pipeline using ModelForTokenClassification head. Named Entity Recognition pipeline using ModelForTokenClassification head.
""" """
default_input_names = 'sequences'
def __init__(self, model, tokenizer: PreTrainedTokenizer = None, def __init__(self, model, tokenizer: PreTrainedTokenizer = None,
modelcard: ModelCard = None, framework: Optional[str] = None, modelcard: ModelCard = None, framework: Optional[str] = None,
args_parser: ArgumentHandler = None, device: int = -1, args_parser: ArgumentHandler = None, device: int = -1,
...@@ -504,7 +507,7 @@ class NerPipeline(Pipeline): ...@@ -504,7 +507,7 @@ class NerPipeline(Pipeline):
for idx, label_idx in enumerate(labels_idx): for idx, label_idx in enumerate(labels_idx):
if self.model.config.id2label[label_idx] not in self.ignore_labels: if self.model.config.id2label[label_idx] not in self.ignore_labels:
answer += [{ answer += [{
'word': self.tokenizer.decode(int(input_ids[idx])), 'word': self.tokenizer.decode([int(input_ids[idx])]),
'score': score[idx][label_idx].item(), 'score': score[idx][label_idx].item(),
'entity': self.model.config.id2label[label_idx] 'entity': self.model.config.id2label[label_idx]
}] }]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment