prune.sh 2.24 KB
Newer Older
jerrrrry's avatar
jerrrrry committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#!/bin/bash

# Absolute directory that contains this script. The invocation path is
# canonicalised first (symlinks resolved), then reduced to its directory part.
SCRIPT_DIR="$(readlink -f "$0")"
SCRIPT_DIR="$(dirname "${SCRIPT_DIR}")"

# Pull conf/arguments.sh into the current shell: it provides the common
# arguments together with the base-model-specific ones.
. "${SCRIPT_DIR}/conf/arguments.sh"

# Pruning is only supported with a tensor-parallel size of exactly 1 (any
# pipeline-parallel size PP >= 1 is accepted): report the problem and stop
# with status 1 for every other TP value.
[ ${TP} -ne 1 ] && {
    printf "${MLM_ERROR} TP must be 1. Only PP>=1 is supported for pruning.\n"
    exit 1
}

# Fixed options that this script always hands to prune.py. The value is one
# flat string: inside double quotes each trailing backslash joins the next
# line, and the launch command at the bottom of the script expands the variable
# unquoted, so the remaining newlines and runs of spaces act only as argument
# separators.
MLM_DEFAULT_ARGS="
    --distributed-timeout-minutes 30 \
    --finetune --auto-detect-ckpt-format \
    --no-gradient-accumulation-fusion \
    --export-te-mcore-model
"

# Pruning targets are requested through environment variables; every variable
# that is set to a non-empty value is turned into the matching prune.py option.
#   export TARGET_HIDDEN_SIZE=3072 TARGET_FFN_HIDDEN_SIZE=9216
#   export LAYERS_TO_DROP="1 5 10"

# Lookup table, one "ENV_VAR:--cli-option" pair per entry. Options are emitted
# in table order.
PRUNE_ARG_MAPPINGS=(
    "TARGET_FFN_HIDDEN_SIZE:--target-ffn-hidden-size"
    "TARGET_HIDDEN_SIZE:--target-hidden-size"
    "TARGET_NUM_ATTENTION_HEADS:--target-num-attention-heads"
    "TARGET_NUM_QUERY_GROUPS:--target-num-query-groups"
    "TARGET_MAMBA_NUM_HEADS:--target-mamba-num-heads"
    "TARGET_MAMBA_HEAD_DIM:--target-mamba-head-dim"
    "TARGET_NUM_LAYERS:--target-num-layers"
    "LAYERS_TO_DROP:--layers-to-drop"
)

# Assemble PRUNE_ARGS as a flat string: " <option> <value>" is appended for
# each table entry whose environment variable holds something.
PRUNE_ARGS=""
for entry in "${PRUNE_ARG_MAPPINGS[@]}"; do
    var_name="${entry%:*}"      # text before the colon: the variable's name
    flag="${entry#*:}"          # text after the colon: the option to emit
    value="${!var_name}"        # indirect lookup of the variable named above
    [ -n "${value}" ] || continue
    PRUNE_ARGS+=" ${flag} ${value}"
done

# An empty PRUNE_ARGS means no pruning target was requested; emit a warning
# and carry on (the script is not aborted here).
[ -n "${PRUNE_ARGS}" ] || printf "${MLM_WARNING} No pruning arguments specified. Set TARGET_* or LAYERS_TO_DROP environment variables.\n"

# Choose where the pruned checkpoint is written. When MLM_MODEL_SAVE is unset
# or empty it falls back to "<MLM_WORK_DIR>/<MLM_MODEL_CFG>_pruned" and a
# warning showing the chosen default is printed.
#
# The test is quoted so that a value containing whitespace is handled as one
# word instead of making `[` fail with "binary operator expected".
# The path itself is passed through %s so that any '%' or '\' in it is printed
# verbatim. MLM_WARNING, PURPLE and WHITE stay inside the format string as
# before; they are defined outside this script (presumably as escape
# sequences - confirm in conf/arguments.sh).
if [ -z "${MLM_MODEL_SAVE}" ]; then
    MLM_MODEL_SAVE="${MLM_WORK_DIR}/${MLM_MODEL_CFG}_pruned"
    printf "${MLM_WARNING} Variable ${PURPLE}MLM_MODEL_SAVE${WHITE} is not set (default: %s)!\n" "${MLM_MODEL_SAVE}"
fi

# Select the checkpoint source for prune.py:
#   - MLM_MODEL_CKPT unset or empty -> --pretrained-model-path <HF_MODEL_CKPT>
#   - otherwise                     -> --load <MLM_MODEL_CKPT>
#
# The test is quoted: unquoted, a path containing whitespace made `[` fail
# with "binary operator expected", and the --load branch was reached only as
# a side effect of that error.
# NOTE(review): LOAD_ARGS is still a flat string that the launch command
# expands unquoted, so a path with spaces is split into several words there.
if [ -z "${MLM_MODEL_CKPT}" ]; then
    LOAD_ARGS="--pretrained-model-path ${HF_MODEL_CKPT}"
else
    LOAD_ARGS="--load ${MLM_MODEL_CKPT}"
fi

# Run prune.py (located next to this script) through the configured launcher.
# Being the last command, its exit status becomes the script's exit status.
#
# Quoting policy:
#   - Single-valued paths ("${SCRIPT_DIR}/prune.py", TOKENIZER_MODEL,
#     MLM_MODEL_SAVE) and the reference label are quoted so that a value
#     containing spaces reaches prune.py as one argument.
#   - LAUNCH_SCRIPT, MODEL_ARGS, LOAD_ARGS, PRUNE_ARGS, MLM_DEFAULT_ARGS and
#     MLM_EXTRA_ARGS are flat strings holding several words each; they are
#     deliberately left unquoted so the shell splits them into arguments.
# The calibration size is fixed at 1024 here.
${LAUNCH_SCRIPT} "${SCRIPT_DIR}/prune.py" \
    ${MODEL_ARGS} \
    ${LOAD_ARGS} \
    --pipeline-model-parallel-size ${PP} \
    --tensor-model-parallel-size ${TP} \
    --tokenizer-model "${TOKENIZER_MODEL}" \
    --save "${MLM_MODEL_SAVE}" \
    --references "${MLM_REF_LABEL}" \
    --calib-size 1024 \
    ${PRUNE_ARGS} \
    ${MLM_DEFAULT_ARGS} ${MLM_EXTRA_ARGS}