Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
4775ec35
"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "32f3955c6d7e98ae506826890f4ff1493de4cd64"
Commit
4775ec35
authored
Dec 20, 2019
by
thomwolf
Browse files
add overwrite - fix ner decoding
parent
f79a7dc6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
14 deletions
+20
-14
transformers/commands/run.py
transformers/commands/run.py
+6
-3
transformers/pipelines.py
transformers/pipelines.py
+14
-11
No files found.
transformers/commands/run.py
View file @
4775ec35
...
@@ -32,7 +32,8 @@ def run_command_factory(args):
...
@@ -32,7 +32,8 @@ def run_command_factory(args):
reader
=
PipelineDataFormat
.
from_str
(
format
=
format
,
reader
=
PipelineDataFormat
.
from_str
(
format
=
format
,
output_path
=
args
.
output
,
output_path
=
args
.
output
,
input_path
=
args
.
input
,
input_path
=
args
.
input
,
column
=
args
.
column
if
args
.
column
else
nlp
.
default_input_names
)
column
=
args
.
column
if
args
.
column
else
nlp
.
default_input_names
,
overwrite
=
args
.
overwrite
)
return
RunCommand
(
nlp
,
reader
)
return
RunCommand
(
nlp
,
reader
)
...
@@ -54,6 +55,7 @@ class RunCommand(BaseTransformersCLICommand):
...
@@ -54,6 +55,7 @@ class RunCommand(BaseTransformersCLICommand):
run_parser
.
add_argument
(
'--column'
,
type
=
str
,
help
=
'Name of the column to use as input. (For multi columns input as QA use column1,columns2)'
)
run_parser
.
add_argument
(
'--column'
,
type
=
str
,
help
=
'Name of the column to use as input. (For multi columns input as QA use column1,columns2)'
)
run_parser
.
add_argument
(
'--format'
,
type
=
str
,
default
=
'infer'
,
choices
=
PipelineDataFormat
.
SUPPORTED_FORMATS
,
help
=
'Input format to read from'
)
run_parser
.
add_argument
(
'--format'
,
type
=
str
,
default
=
'infer'
,
choices
=
PipelineDataFormat
.
SUPPORTED_FORMATS
,
help
=
'Input format to read from'
)
run_parser
.
add_argument
(
'--device'
,
type
=
int
,
default
=-
1
,
help
=
'Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)'
)
run_parser
.
add_argument
(
'--device'
,
type
=
int
,
default
=-
1
,
help
=
'Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)'
)
run_parser
.
add_argument
(
'--overwrite'
,
action
=
'store_true'
,
help
=
'Allow overwriting the output file.'
)
run_parser
.
set_defaults
(
func
=
run_command_factory
)
run_parser
.
set_defaults
(
func
=
run_command_factory
)
def
run
(
self
):
def
run
(
self
):
...
@@ -61,6 +63,7 @@ class RunCommand(BaseTransformersCLICommand):
...
@@ -61,6 +63,7 @@ class RunCommand(BaseTransformersCLICommand):
for
entry
in
self
.
_reader
:
for
entry
in
self
.
_reader
:
output
=
nlp
(
**
entry
)
if
self
.
_reader
.
is_multi_columns
else
nlp
(
entry
)
output
=
nlp
(
**
entry
)
if
self
.
_reader
.
is_multi_columns
else
nlp
(
entry
)
print
(
output
)
if
isinstance
(
output
,
dict
):
if
isinstance
(
output
,
dict
):
outputs
.
append
(
output
)
outputs
.
append
(
output
)
else
:
else
:
...
@@ -68,10 +71,10 @@ class RunCommand(BaseTransformersCLICommand):
...
@@ -68,10 +71,10 @@ class RunCommand(BaseTransformersCLICommand):
# Saving data
# Saving data
if
self
.
_nlp
.
binary_output
:
if
self
.
_nlp
.
binary_output
:
binary_path
=
self
.
_reader
.
save_binary
(
output
)
binary_path
=
self
.
_reader
.
save_binary
(
output
s
)
logger
.
warning
(
'Current pipeline requires output to be in binary format, saving at {}'
.
format
(
binary_path
))
logger
.
warning
(
'Current pipeline requires output to be in binary format, saving at {}'
.
format
(
binary_path
))
else
:
else
:
self
.
_reader
.
save
(
output
)
self
.
_reader
.
save
(
output
s
)
transformers/pipelines.py
View file @
4775ec35
...
@@ -107,7 +107,7 @@ class PipelineDataFormat:
...
@@ -107,7 +107,7 @@ class PipelineDataFormat:
"""
"""
SUPPORTED_FORMATS
=
[
'json'
,
'csv'
,
'pipe'
]
SUPPORTED_FORMATS
=
[
'json'
,
'csv'
,
'pipe'
]
def
__init__
(
self
,
output_path
:
Optional
[
str
],
input_path
:
Optional
[
str
],
column
:
Optional
[
str
]):
def
__init__
(
self
,
output_path
:
Optional
[
str
],
input_path
:
Optional
[
str
],
column
:
Optional
[
str
]
,
overwrite
=
False
):
self
.
output_path
=
output_path
self
.
output_path
=
output_path
self
.
input_path
=
input_path
self
.
input_path
=
input_path
self
.
column
=
column
.
split
(
','
)
if
column
is
not
None
else
[
''
]
self
.
column
=
column
.
split
(
','
)
if
column
is
not
None
else
[
''
]
...
@@ -116,7 +116,7 @@ class PipelineDataFormat:
...
@@ -116,7 +116,7 @@ class PipelineDataFormat:
if
self
.
is_multi_columns
:
if
self
.
is_multi_columns
:
self
.
column
=
[
tuple
(
c
.
split
(
'='
))
if
'='
in
c
else
(
c
,
c
)
for
c
in
self
.
column
]
self
.
column
=
[
tuple
(
c
.
split
(
'='
))
if
'='
in
c
else
(
c
,
c
)
for
c
in
self
.
column
]
if
output_path
is
not
None
:
if
output_path
is
not
None
and
not
overwrite
:
if
exists
(
abspath
(
self
.
output_path
)):
if
exists
(
abspath
(
self
.
output_path
)):
raise
OSError
(
'{} already exists on disk'
.
format
(
self
.
output_path
))
raise
OSError
(
'{} already exists on disk'
.
format
(
self
.
output_path
))
...
@@ -152,25 +152,26 @@ class PipelineDataFormat:
...
@@ -152,25 +152,26 @@ class PipelineDataFormat:
return
binary_path
return
binary_path
@
staticmethod
@
staticmethod
def
from_str
(
format
:
str
,
output_path
:
Optional
[
str
],
input_path
:
Optional
[
str
],
column
:
Optional
[
str
]):
def
from_str
(
format
:
str
,
output_path
:
Optional
[
str
],
input_path
:
Optional
[
str
],
column
:
Optional
[
str
]
,
overwrite
=
False
):
if
format
==
'json'
:
if
format
==
'json'
:
return
JsonPipelineDataFormat
(
output_path
,
input_path
,
column
)
return
JsonPipelineDataFormat
(
output_path
,
input_path
,
column
,
overwrite
=
overwrite
)
elif
format
==
'csv'
:
elif
format
==
'csv'
:
return
CsvPipelineDataFormat
(
output_path
,
input_path
,
column
)
return
CsvPipelineDataFormat
(
output_path
,
input_path
,
column
,
overwrite
=
overwrite
)
elif
format
==
'pipe'
:
elif
format
==
'pipe'
:
return
PipedPipelineDataFormat
(
output_path
,
input_path
,
column
)
return
PipedPipelineDataFormat
(
output_path
,
input_path
,
column
,
overwrite
=
overwrite
)
else
:
else
:
raise
KeyError
(
'Unknown reader {} (Available reader are json/csv/pipe)'
.
format
(
format
))
raise
KeyError
(
'Unknown reader {} (Available reader are json/csv/pipe)'
.
format
(
format
))
class
CsvPipelineDataFormat
(
PipelineDataFormat
):
class
CsvPipelineDataFormat
(
PipelineDataFormat
):
def
__init__
(
self
,
output_path
:
Optional
[
str
],
input_path
:
Optional
[
str
],
column
:
Optional
[
str
]):
def
__init__
(
self
,
output_path
:
Optional
[
str
],
input_path
:
Optional
[
str
],
column
:
Optional
[
str
]
,
overwrite
=
False
):
super
().
__init__
(
output_path
,
input_path
,
column
)
super
().
__init__
(
output_path
,
input_path
,
column
,
overwrite
=
overwrite
)
def
__iter__
(
self
):
def
__iter__
(
self
):
with
open
(
self
.
input_path
,
'r'
)
as
f
:
with
open
(
self
.
input_path
,
'r'
)
as
f
:
reader
=
csv
.
DictReader
(
f
)
reader
=
csv
.
DictReader
(
f
)
for
row
in
reader
:
for
row
in
reader
:
print
(
row
,
self
.
column
)
if
self
.
is_multi_columns
:
if
self
.
is_multi_columns
:
yield
{
k
:
row
[
c
]
for
k
,
c
in
self
.
column
}
yield
{
k
:
row
[
c
]
for
k
,
c
in
self
.
column
}
else
:
else
:
...
@@ -185,8 +186,8 @@ class CsvPipelineDataFormat(PipelineDataFormat):
...
@@ -185,8 +186,8 @@ class CsvPipelineDataFormat(PipelineDataFormat):
class
JsonPipelineDataFormat
(
PipelineDataFormat
):
class
JsonPipelineDataFormat
(
PipelineDataFormat
):
def
__init__
(
self
,
output_path
:
Optional
[
str
],
input_path
:
Optional
[
str
],
column
:
Optional
[
str
]):
def
__init__
(
self
,
output_path
:
Optional
[
str
],
input_path
:
Optional
[
str
],
column
:
Optional
[
str
]
,
overwrite
=
False
):
super
().
__init__
(
output_path
,
input_path
,
column
)
super
().
__init__
(
output_path
,
input_path
,
column
,
overwrite
=
overwrite
)
with
open
(
input_path
,
'r'
)
as
f
:
with
open
(
input_path
,
'r'
)
as
f
:
self
.
_entries
=
json
.
load
(
f
)
self
.
_entries
=
json
.
load
(
f
)
...
@@ -460,6 +461,8 @@ class NerPipeline(Pipeline):
...
@@ -460,6 +461,8 @@ class NerPipeline(Pipeline):
Named Entity Recognition pipeline using ModelForTokenClassification head.
Named Entity Recognition pipeline using ModelForTokenClassification head.
"""
"""
default_input_names
=
'sequences'
def
__init__
(
self
,
model
,
tokenizer
:
PreTrainedTokenizer
=
None
,
def
__init__
(
self
,
model
,
tokenizer
:
PreTrainedTokenizer
=
None
,
modelcard
:
ModelCard
=
None
,
framework
:
Optional
[
str
]
=
None
,
modelcard
:
ModelCard
=
None
,
framework
:
Optional
[
str
]
=
None
,
args_parser
:
ArgumentHandler
=
None
,
device
:
int
=
-
1
,
args_parser
:
ArgumentHandler
=
None
,
device
:
int
=
-
1
,
...
@@ -504,7 +507,7 @@ class NerPipeline(Pipeline):
...
@@ -504,7 +507,7 @@ class NerPipeline(Pipeline):
for
idx
,
label_idx
in
enumerate
(
labels_idx
):
for
idx
,
label_idx
in
enumerate
(
labels_idx
):
if
self
.
model
.
config
.
id2label
[
label_idx
]
not
in
self
.
ignore_labels
:
if
self
.
model
.
config
.
id2label
[
label_idx
]
not
in
self
.
ignore_labels
:
answer
+=
[{
answer
+=
[{
'word'
:
self
.
tokenizer
.
decode
(
int
(
input_ids
[
idx
])),
'word'
:
self
.
tokenizer
.
decode
(
[
int
(
input_ids
[
idx
])
]
),
'score'
:
score
[
idx
][
label_idx
].
item
(),
'score'
:
score
[
idx
][
label_idx
].
item
(),
'entity'
:
self
.
model
.
config
.
id2label
[
label_idx
]
'entity'
:
self
.
model
.
config
.
id2label
[
label_idx
]
}]
}]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment