chenpangpang / parler-tts · Commits

Commit 42b5aac6
Authored Feb 14, 2024 by sanchit-gandhi
Parent: 98482e58

    start llm prompts

Showing 1 changed file with 178 additions and 0 deletions.

run_prompt_creation.py (new file, mode 100644): +178, -0
import os
import sys
from typing import Optional, Dict
import logging

import torch
from accelerate import Accelerator
from datasets import load_dataset
from transformers import AutoModelForCausalLM, HfArgumentParser, BitsAndBytesConfig, AutoTokenizer
from dataclasses import dataclass, field


logger = logging.getLogger(__name__)
@dataclass
class ModelArguments:
    """
    Arguments pertaining to the model we are going to use for prompt annotation.
    """

    model_name_or_path: str = field(
        default=None,
        metadata={"help": "The name of the model to use (via the transformers library) for the prompt annotation."},
    )
    model_variant: str = field(
        default=None,
        metadata={"help": "If specified, load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    torch_dtype: Optional[str] = field(
        default="float16",
        metadata={
            "help": (
                "Floating-point format in which the model weights should be initialized"
                " and the computations run. Choose one of `[float32, float16, bfloat16]`."
            )
        },
    )
    attn_implementation: Optional[str] = field(
        default="sdpa",
        metadata={"help": "Which attn type to use: ['eager', 'sdpa', 'flash_attention_2']"},
    )
    load_in_8bit: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use 8-bit precision for inference."},
    )
    load_in_4bit: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use 4-bit precision for inference."},
    )
    bnb_4bit_quant_type: Optional[str] = field(
        default="nf4",
        metadata={"help": "The 4-bit quantization type to use (fp4 or nf4)."},
    )
    use_bnb_nested_quant: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use nested (double) quantization."},
    )
    trust_remote_code: Optional[bool] = field(
        default=False,
        metadata={
            "help": (
                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
                "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
                "execute code present on the Hub on your local machine."
            )
        },
    )
    use_fast_tokenizer: Optional[bool] = field(
        default=True,
        metadata={"help": "Use fast tokenizer for encoding/decoding input ids"},
    )
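
# Note: `torch_dtype` is kept as a string and resolved later in `main` via
# `getattr(torch, model_args.torch_dtype)`, so e.g. "bfloat16" becomes
# `torch.bfloat16`. When `load_in_4bit` is set, the same value also determines
# `bnb_4bit_compute_dtype` in `get_quantization_config` below.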
@dataclass
class DataArguments:
    """
    Arguments pertaining to the data we are going to use for prompt annotation.
    """

    dataset_name: str = field(
        default=None,
        metadata={"help": "The name of the dataset to use (via the datasets library)."},
    )
    dataset_config_name: Optional[str] = field(
        default=None,
        metadata={"help": "The configuration name of the dataset to use (via the datasets library)."},
    )
    dataset_split_name: Optional[str] = field(
        default=None,
        metadata={"help": "The split name of the dataset to use (via the datasets library)."},
    )
    dataset_cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to cache directory for saving and loading datasets"},
    )
    samples_per_dataset: Optional[int] = field(
        default=None,
        metadata={"help": "Number of samples per dataset used to measure speed."},
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets"},
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
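
# Note: `dataset_name` and `dataset_config_name` feed directly into the
# `load_dataset(...)` call at the bottom of this file; the caching and
# num-workers fields configure how that dataset is fetched and preprocessed.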
def get_quantization_config(model_args: ModelArguments) -> BitsAndBytesConfig | None:
    if model_args.load_in_4bit:
        compute_dtype = torch.float16
        if model_args.torch_dtype not in {"auto", None}:
            compute_dtype = getattr(torch, model_args.torch_dtype)

        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
            bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
        )
    elif model_args.load_in_8bit:
        quantization_config = BitsAndBytesConfig(
            load_in_8bit=True,
        )
    else:
        quantization_config = None

    return quantization_config
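
# A minimal sketch of what this helper returns for a hypothetical flag combination
# `--load_in_4bit True --torch_dtype bfloat16` (values shown for illustration only):
#
#   BitsAndBytesConfig(
#       load_in_4bit=True,
#       bnb_4bit_compute_dtype=torch.bfloat16,
#       bnb_4bit_quant_type="nf4",
#       bnb_4bit_use_double_quant=False,
#   )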
def get_current_device() -> int | str:
    """Get the current device. For GPU we return the local process index to enable multiple GPU training."""
    return Accelerator().local_process_index if torch.cuda.is_available() else "cpu"


def get_kbit_device_map() -> Dict[str, int] | None:
    """Useful for running inference with quantized models by setting `device_map=get_kbit_device_map()`."""
    return {"": get_current_device()} if torch.cuda.is_available() else None
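
# Note: the `{"": device}` mapping assigns every module of the model to a single
# device, so under a multi-GPU `accelerate` launch each process holds its own full
# copy of the quantized model on its local GPU (data parallelism, not sharding).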
def main():
    # 1. Parse input arguments
    parser = HfArgumentParser((ModelArguments, DataArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args = parser.parse_args_into_dataclasses()
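
    # For reference, a hypothetical `config.json` that could be passed as the single
    # argument (keys must match the dataclass field names above; values are examples only):
    #
    #   {
    #       "model_name_or_path": "mistralai/Mistral-7B-Instruct-v0.2",
    #       "torch_dtype": "bfloat16",
    #       "load_in_4bit": true,
    #       "dataset_name": "your-username/your-dataset",
    #       "dataset_config_name": "default",
    #       "dataset_split_name": "train"
    #   }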
    # 2. Setup logging
    # Make one log on every process with the configuration for debugging.
    logger.setLevel(logging.INFO)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    # 3. Load pre-trained model
    logger.info("*** Load pretrained model ***")
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)

    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        revision=model_args.model_revision,
        variant=model_args.model_variant,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
        low_cpu_mem_usage=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        use_fast=model_args.use_fast_tokenizer,
    )

    # 4. Load annotation dataset
    raw_dataset = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
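
# Example invocation (model and dataset names are placeholders, not from this commit):
#
#   python run_prompt_creation.py \
#       --model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 \
#       --torch_dtype bfloat16 \
#       --load_in_4bit True \
#       --dataset_name your-username/your-dataset \
#       --dataset_config_name default \
#       --dataset_split_name train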